from __future__ import annotations
import asyncio
import inspect
import types
import typing
from collections.abc import Awaitable, Callable, Mapping, Sequence
from dataclasses import dataclass
from functools import wraps
from typing import Any, Generic, cast, get_args, get_origin, get_type_hints
import msgspec
from loom.core.cache.abc.backend import CacheBackend
from loom.core.cache.abc.config import CacheConfig
from loom.core.cache.abc.dependency import DependencyResolver
from loom.core.cache.keys import entity_key, list_index_key, stable_hash
from loom.core.logger import get_logger
from loom.core.model.enums import Cardinality
from loom.core.model.introspection import get_projections, get_relations
from loom.core.model.projection import Projection
from loom.core.model.relation import Relation
from loom.core.repository import FilterParams, MutationEvent, PageParams, PageResult, Repository
from loom.core.repository.abc.query import CursorResult, FilterGroup, PaginationMode, QuerySpec
from loom.core.repository.abc.repository import CreateT, IdT, OutputT, UpdateT
def _resolve_loader_model(loader: Any) -> type | None:
"""Return the domain model class from a projection loader, handling callables."""
raw = getattr(loader, "model", None)
if raw is None:
return None
if isinstance(raw, type):
return raw
if callable(raw):
try:
resolved = raw()
return resolved if isinstance(resolved, type) else None
except Exception:
return None
return None
def _list_element_type(annotation: Any) -> type | None:
"""Return T for ``list[T]`` annotations, or ``None`` if not a typed list.
Handles ``list[T] | UnsetType`` produced by ``LoomStructMeta`` when it
widens relation field annotations at class-definition time.
"""
origin = get_origin(annotation)
# Unwrap Union / UnionType (e.g. list[T] | UnsetType) to find the list arm.
if origin in {typing.Union, types.UnionType}:
for arg in get_args(annotation):
result = _list_element_type(arg)
if result is not None:
return result
return None
if origin is list:
args = get_args(annotation)
if args and isinstance(args[0], type):
return args[0]
return None
def _infer_otm_cache_dep(model: type, attr_name: str, rel: Relation) -> str | None:
    """Auto-infer an ``entity:fk_col`` cache spec for a ONE_TO_MANY relation.

    The child entity name comes from the ``__tablename__`` of the element
    type of the field's ``list[T]`` annotation; the FK column is the last
    dotted segment of ``rel.foreign_key``. Returns ``None`` for other
    cardinalities, unresolvable annotations, or children lacking a
    ``__tablename__``.
    """
    if rel.cardinality is not Cardinality.ONE_TO_MANY:
        return None
    try:
        annotations = get_type_hints(model)
    except Exception:
        return None
    child = _list_element_type(annotations.get(attr_name))
    if child is None or not hasattr(child, "__tablename__"):
        return None
    fk_column = rel.foreign_key.rsplit(".", 1)[-1]
    return f"{child.__tablename__}:{fk_column}"
def _infer_projection_cache_dep(model: type, proj: Projection) -> str | None:
    """Auto-infer an ``entity:fk_col`` cache spec for a projection loader.

    Finds the ONE_TO_MANY relation on ``model`` whose list-element type is
    the loader's model, then reuses that relation's FK column. Returns
    ``None`` when the loader carries no model, nothing matches, or type
    hints cannot be resolved.
    """
    target = _resolve_loader_model(proj.loader)
    if target is None or not hasattr(target, "__tablename__"):
        return None
    try:
        annotations = get_type_hints(model)
    except Exception:
        return None
    for field_name, relation in get_relations(model).items():
        if relation.cardinality is not Cardinality.ONE_TO_MANY:
            continue
        if _list_element_type(annotations.get(field_name)) is not target:
            continue
        fk_column = relation.foreign_key.rsplit(".", 1)[-1]
        return f"{target.__tablename__}:{fk_column}"
    return None
class _ListIndexPayload(msgspec.Struct):
    """Cached index payload for ``list_paginated``: ordered entity ids plus total count."""
    # Ordered entity ids making up one cached page.
    ids: list[Any]
    # Total matching rows across all pages (drives has_next computation).
    total_count: int
class _QueryIndexPayload(msgspec.Struct):
    """Cached index payload for ``list_with_query``.

    ``total_count`` is filled for offset pagination; ``next_cursor`` and
    ``has_next`` carry cursor-pagination state (first page only).
    """
    # Ordered entity ids for the cached result set.
    ids: list[Any]
    total_count: int | None = None
    next_cursor: str | None = None
    has_next: bool = False
@dataclass(frozen=True, slots=True)
class _DependencySpec:
    """Parsed ``'<entity>:<fk_field>'`` cache-dependency specification."""
    # Entity (table) name whose mutations invalidate this repository's cache.
    entity: str
    # FK column on that entity linking its rows back to ours.
    fk_field: str
class CachedRepository(
    Repository[OutputT, CreateT, UpdateT, IdT],
    Generic[OutputT, CreateT, UpdateT, IdT],
):
    """Cache-aside wrapper with generational invalidation.

    Reads try the cache first and fall back to the wrapped repository;
    mutations bump invalidation tags via the dependency resolver, which
    rotates cache keys (their fingerprints) instead of deleting entries.
    """
    def __init__(
        self,
        repository: Repository[OutputT, CreateT, UpdateT, IdT],
        *,
        config: CacheConfig,
        cache: CacheBackend,
        dependency_resolver: DependencyResolver,
    ) -> None:
        """Wrap ``repository`` with cache-aside reads and event-driven invalidation.

        Args:
            repository: Inner repository all reads/writes delegate to.
            config: TTL policy source (single-entity vs list TTLs).
            cache: Key/value backend holding entity and index payloads.
            dependency_resolver: Maps mutation events to tag bumps and
                computes the tag fingerprints embedded in cache keys.
        """
        self._repository = repository
        self._config = config
        self._cache = cache
        self._resolver = dependency_resolver
        # Fall back to the lowercase class name when the repository does
        # not expose an explicit ``entity_name`` attribute.
        fallback_name = repository.__class__.__name__.lower()
        self._entity_name = getattr(repository, "entity_name", fallback_name)
        # Declared + auto-inferred '<entity>:<fk>' specs, parsed up-front.
        self._depends_on = self._parse_dependency_specs(self._collect_dependency_specs(repository))
        self._log = get_logger(__name__).bind(repository=repository.__class__.__name__)
    @property
    def entity_name(self) -> str:
        """Normalized name of the cached entity."""
        return self._entity_name
async def get_by_id(self, obj_id: IdT, profile: str = "default") -> OutputT | None:
tags = self._resolver.entity_tags(self.entity_name, obj_id)
tags.extend(self._entity_dependency_tags(obj_id))
fingerprint = await self._resolver.fingerprint(tags)
key = entity_key(self.entity_name, obj_id, profile, fingerprint)
cached_payload = await self._cache.get_value(key)
if cached_payload is not None:
self._log.debug("CacheHitEntity", key=key)
return self._to_output_from_cache(cached_payload)
self._log.debug("CacheMissEntity", key=key)
loaded = await self._repository.get_by_id(obj_id, profile=profile)
if loaded is None:
return None
ttl = self._config.ttl_for_single(self.entity_name)
await self._cache.set_value(key, self._to_builtins(loaded), ttl=ttl)
return loaded
async def get_by(
self,
field: str,
value: Any,
profile: str = "default",
) -> OutputT | None:
"""Fetch one entity by arbitrary field.
This path intentionally delegates to the wrapped repository without
cache-aside behavior for now. Field-based lookups can target mutable
columns and the cache invalidation surface is broader than id-based
access; keeping it uncached preserves correctness while the lookup
cache policy is designed explicitly.
"""
return await self._repository.get_by(field, value, profile=profile)
async def exists_by(self, field: str, value: Any) -> bool:
"""Check existence by arbitrary field.
Existence checks are delegated directly to the wrapped repository to
avoid stale negative/positive cache entries on mutable fields.
"""
return await self._repository.exists_by(field, value)
    async def list_paginated(
        self,
        page_params: PageParams,
        filter_params: FilterParams | None = None,
        profile: str = "default",
    ) -> PageResult[OutputT]:
        """Cache-aside pagination backed by an id index.

        Only an index (ordered ids + total count) is cached per page; the
        entities themselves live in their own cache entries, so a single
        entity invalidation does not stale every list page. A hit is
        honored only when every indexed entity still hydrates.
        """
        filters_payload = self._serialize_filters(filter_params)
        filter_fingerprint = stable_hash(filters_payload)
        tags = self._resolver.list_tags(self.entity_name, filter_fingerprint)
        tags.extend(self._list_dependency_tags())
        # The fingerprint is baked into the key: bumping any tag rotates
        # the key, invalidating old indexes without explicit deletes.
        fingerprint = await self._resolver.fingerprint(tags)
        index_key = list_index_key(
            self.entity_name,
            filter_fingerprint,
            page=page_params.page,
            limit=page_params.limit,
            profile=profile,
            deps_fingerprint=fingerprint,
        )
        cached_index = await self._cache.get_value(index_key, type=_ListIndexPayload)
        if cached_index is not None:
            entity_ids = cast(list[IdT], cached_index.ids)
            total_count = cached_index.total_count
            items = await self._load_items_from_index(entity_ids, profile=profile)
            # A short result means some entity could not be hydrated; fall
            # through and rebuild the index from the repository.
            if len(items) == len(entity_ids):
                self._log.debug("CacheHitList", key=index_key)
                return PageResult(
                    items=tuple(items),
                    total_count=total_count,
                    page=page_params.page,
                    limit=page_params.limit,
                    has_next=(page_params.offset + len(items)) < total_count,
                )
        self._log.debug("CacheMissList", key=index_key)
        page = await self._repository.list_paginated(
            page_params,
            filter_params=filter_params,
            profile=profile,
        )
        # Items without a resolvable id cannot be re-hydrated later, so
        # they are left out of the cached index.
        ids = [
            cast(IdT, entity_id)
            for item in page.items
            for entity_id in [self._extract_entity_id(item)]
            if entity_id is not None
        ]
        index_to_store = _ListIndexPayload(ids=ids, total_count=page.total_count)
        ttl = self._config.ttl_for_list(self.entity_name)
        await self._cache.set_value(index_key, index_to_store, ttl=ttl)
        await self._cache_entity_batch(page.items, profile=profile)
        return page
    async def list_with_query(
        self,
        query: QuerySpec,
        profile: str = "default",
    ) -> PageResult[OutputT] | CursorResult[OutputT]:
        """Cache-aside querying; mirrors ``list_paginated`` with a query key.

        Offset-paginated queries are always cacheable; cursor-paginated
        queries cache only their first page (``query.cursor is None``).
        """
        is_cursor = query.pagination == PaginationMode.CURSOR
        # For cursor pagination, cache only the first page to avoid
        # unbounded key growth and low-hit deep-page caches.
        should_cache = not is_cursor or query.cursor is None
        if not should_cache:
            return await self._repository.list_with_query(query, profile=profile)
        query_payload = self._serialize_query(query)
        query_fingerprint = stable_hash(repr(query_payload))
        tags = self._resolver.list_tags(self.entity_name, query_fingerprint)
        tags.extend(self._list_dependency_tags())
        deps_fingerprint = await self._resolver.fingerprint(tags)
        query_key = (
            f"{self.entity_name}:query:{query_fingerprint}:"
            f"profile={profile}:deps={deps_fingerprint}"
        )
        cached_index = await self._cache.get_value(query_key, type=_QueryIndexPayload)
        if cached_index is not None:
            entity_ids = cast(list[IdT], cached_index.ids)
            items = await self._load_items_from_index(entity_ids, profile=profile)
            # Honor the hit only when every indexed entity hydrated.
            if len(items) == len(entity_ids):
                self._log.debug("CacheHitQuery", key=query_key)
                if is_cursor:
                    return CursorResult(
                        items=tuple(items),
                        next_cursor=cached_index.next_cursor,
                        has_next=cached_index.has_next,
                    )
                return PageResult(
                    items=tuple(items),
                    total_count=0 if cached_index.total_count is None else cached_index.total_count,
                    page=query.page,
                    limit=query.limit,
                    has_next=cached_index.has_next,
                )
        self._log.debug("CacheMissQuery", key=query_key)
        loaded = await self._repository.list_with_query(query, profile=profile)
        # Index only items whose id is resolvable (re-hydratable later).
        ids = [
            cast(IdT, entity_id)
            for item in loaded.items
            for entity_id in [self._extract_entity_id(item)]
            if entity_id is not None
        ]
        ttl = self._config.ttl_for_list(self.entity_name)
        if isinstance(loaded, CursorResult):
            # Cursor results carry pagination state instead of a total.
            await self._cache.set_value(
                query_key,
                _QueryIndexPayload(
                    ids=ids,
                    next_cursor=loaded.next_cursor,
                    has_next=loaded.has_next,
                ),
                ttl=ttl,
            )
            await self._cache_entity_batch(loaded.items, profile=profile)
            return loaded
        await self._cache.set_value(
            query_key,
            _QueryIndexPayload(
                ids=ids,
                total_count=loaded.total_count,
                has_next=loaded.has_next,
            ),
            ttl=ttl,
        )
        await self._cache_entity_batch(loaded.items, profile=profile)
        return loaded
async def create(self, data: CreateT) -> OutputT:
created = await self._repository.create(data)
entity_id = getattr(created, "id", None)
await self._resolver.bump_from_events(
(
MutationEvent(
entity=self.entity_name,
op="create",
ids=() if entity_id is None else (entity_id,),
changed_fields=frozenset(self._struct_keys(data)),
),
)
)
return created
async def update(self, obj_id: IdT, data: UpdateT) -> OutputT | None:
updated = await self._repository.update(obj_id, data)
if updated is None:
return None
await self._resolver.bump_from_events(
(
MutationEvent(
entity=self.entity_name,
op="update",
ids=(obj_id,),
changed_fields=frozenset(self._struct_keys(data)),
),
)
)
return updated
async def delete(self, obj_id: IdT) -> bool:
deleted = await self._repository.delete(obj_id)
if deleted:
await self._resolver.bump_from_events(
(
MutationEvent(
entity=self.entity_name,
op="delete",
ids=(obj_id,),
),
)
)
return deleted
async def on_transaction_committed(self, events: tuple[MutationEvent, ...]) -> None:
await self._resolver.bump_from_events(events)
post_commit = getattr(self._repository, "on_transaction_committed", None)
if inspect.iscoroutinefunction(post_commit):
handler = cast(Callable[[tuple[MutationEvent, ...]], Awaitable[None]], post_commit)
await handler(events)
def __getattr__(self, name: str) -> Any:
attr = getattr(self._repository, name)
if not callable(attr):
return attr
metadata = getattr(attr, "__cache_query__", None)
if metadata is None:
return attr
if inspect.iscoroutinefunction(attr):
return self._wrap_custom_cached_method(
name,
cast(Callable[..., Awaitable[Any]], attr),
metadata,
)
return attr
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _extract_entity_id(self, item: Any) -> object | None:
value = getattr(item, "id", None)
return value
    def _wrap_custom_cached_method(
        self,
        method_name: str,
        method: Callable[..., Awaitable[Any]],
        metadata: dict[str, object],
    ) -> Callable[..., Awaitable[Any]]:
        """Wrap a repository coroutine marked with ``__cache_query__`` metadata.

        ``metadata['scope']`` selects tagging: ``'entity'`` keys off the
        first positional argument (presumably the entity id — confirm with
        the decorator's contract); anything else is treated as a list-style
        query. ``metadata['ttl_key']`` optionally overrides the TTL lookup
        name.

        NOTE(review): on a cache hit the stored builtins payload is returned
        as-is (not passed through ``_to_output_from_cache``), so hit and
        miss paths may yield different shapes for struct results — verify
        callers accept both.
        """
        @wraps(method)
        async def wrapped(*args: Any, **kwargs: Any) -> Any:
            # Hash the full call signature so each argument set caches separately.
            raw = {"args": self._to_builtins(args), "kwargs": self._to_builtins(kwargs)}
            raw_hash = stable_hash(repr(raw))
            scope = cast(str, metadata.get("scope") or "list")
            ttl_key = cast(str | None, metadata.get("ttl_key"))
            if scope == "entity":
                # Fall back to the call hash when no positional id is given.
                entity_id: object = args[0] if args else raw_hash
                tags = self._resolver.entity_tags(self.entity_name, entity_id)
                tags.extend(self._entity_dependency_tags(entity_id))
                ttl = self._config.ttl_for_single(ttl_key or self.entity_name)
            else:
                tags = self._resolver.list_tags(self.entity_name, raw_hash)
                tags.extend(self._list_dependency_tags())
                ttl = self._config.ttl_for_list(ttl_key or self.entity_name)
            fingerprint = await self._resolver.fingerprint(tags)
            key = f"{self.entity_name}:custom:{method_name}:{raw_hash}:deps={fingerprint}"
            cached_payload = await self._cache.get_value(key)
            if cached_payload is not None:
                self._log.debug("CacheHitCustomMethod", key=key, method=method_name)
                return cached_payload
            result = await method(*args, **kwargs)
            await self._cache.set_value(key, self._to_builtins(result), ttl=ttl)
            self._log.debug("CacheMissCustomMethod", key=key, method=method_name)
            return result
        return wrapped
async def _load_items_from_index(self, ids: list[IdT], profile: str) -> list[OutputT]:
tags_by_id = [
self._resolver.entity_tags(self.entity_name, eid) + self._entity_dependency_tags(eid)
for eid in ids
]
fingerprints = list(
await asyncio.gather(*(self._resolver.fingerprint(tags) for tags in tags_by_id))
)
entity_keys = [
entity_key(self.entity_name, eid, profile, fp)
for eid, fp in zip(ids, fingerprints, strict=False)
]
cached_values = await self._cache.multi_get_values(entity_keys)
items: list[OutputT] = []
missing_ids: list[IdT] = []
missing_positions: list[int] = []
for index, value in enumerate(cached_values):
if value is None:
missing_ids.append(ids[index])
missing_positions.append(index)
items.append(cast(OutputT, None))
continue
restored = self._to_output_from_cache(value)
if restored is None:
missing_ids.append(ids[index])
missing_positions.append(index)
items.append(cast(OutputT, None))
continue
items.append(restored)
if not missing_ids:
return items
for missing_id, position in zip(missing_ids, missing_positions, strict=False):
loaded = await self._repository.get_by_id(missing_id, profile=profile)
if loaded is None:
return []
items[position] = loaded
refill_pairs: list[tuple[str, Any]] = []
ttl = self._config.ttl_for_single(self.entity_name)
for obj, ek in zip(items, entity_keys, strict=False):
refill_pairs.append((ek, self._to_builtins(obj)))
await self._cache.multi_set_values(refill_pairs, ttl=ttl)
return items
async def _cache_entity_batch(self, items: Sequence[OutputT], profile: str) -> None:
if not items:
return
ttl = self._config.ttl_for_single(self.entity_name)
entity_ids = [getattr(item, "id", None) for item in items]
tags_by_id = [
self._resolver.entity_tags(self.entity_name, entity_id)
+ self._entity_dependency_tags(entity_id)
for entity_id in entity_ids
]
fingerprints = await asyncio.gather(
*(self._resolver.fingerprint(tags) for tags in tags_by_id)
)
pairs: list[tuple[str, Any]] = []
for item, entity_id, fingerprint in zip(items, entity_ids, fingerprints, strict=False):
key = entity_key(self.entity_name, entity_id, profile, fingerprint)
pairs.append((key, self._to_builtins(item)))
await self._cache.multi_set_values(pairs, ttl=ttl)
def _to_output_from_cache(self, payload: Any) -> OutputT | None:
if payload is None:
return None
if isinstance(payload, msgspec.Struct):
return cast(OutputT, payload)
if isinstance(payload, Mapping):
builder = getattr(self._repository, "to_output_from_payload", None)
if callable(builder):
return cast(OutputT, builder(payload))
return cast(OutputT, payload)
def _list_dependency_tags(self) -> list[str]:
tags: list[str] = []
for dependency in self._depends_on:
tags.append(dependency.entity)
tags.append(f"{dependency.entity}:list")
return tags
def _entity_dependency_tags(self, obj_id: object) -> list[str]:
tags: list[str] = []
for dependency in self._depends_on:
tags.append(dependency.entity)
tags.append(f"{dependency.entity}:list")
tags.append(f"{dependency.entity}:{dependency.fk_field}:{obj_id}")
return tags
def _parse_dependency_specs(self, specs: tuple[str, ...]) -> list[_DependencySpec]:
deps: list[_DependencySpec] = []
for spec in specs:
if ":" not in spec:
msg = f"Invalid dependency spec '{spec}'. Expected '<entity>:<fk_field>'"
raise ValueError(msg)
entity_name, fk_field = spec.split(":", 1)
entity = entity_name.strip()
field = fk_field.strip()
if not entity or not field:
raise ValueError(f"Invalid dependency spec '{spec}'. Empty entity or fk field")
deps.append(_DependencySpec(entity=entity, fk_field=field))
return deps
def _collect_dependency_specs(
self,
repository: Repository[OutputT, CreateT, UpdateT, IdT],
) -> tuple[str, ...]:
specs: list[str] = list(getattr(repository, "depends_on", ()))
model = getattr(repository, "model", None)
if model is None:
return tuple(specs)
for attr_name, rel in get_relations(model).items():
if rel.depends_on:
specs.extend(rel.depends_on)
else:
inferred = _infer_otm_cache_dep(model, attr_name, rel)
if inferred is not None:
specs.append(inferred)
for proj in get_projections(model).values():
if proj.depends_on:
specs.extend(proj.depends_on)
else:
inferred = _infer_projection_cache_dep(model, proj)
if inferred is not None:
specs.append(inferred)
return tuple(dict.fromkeys(specs))
def _to_builtins(self, value: Any) -> Any:
if isinstance(value, msgspec.Struct):
return msgspec.to_builtins(value)
if isinstance(value, list | tuple):
return [self._to_builtins(item) for item in value]
if isinstance(value, dict):
return {str(key): self._to_builtins(item) for key, item in value.items()}
return value
def _serialize_filters(self, filter_params: FilterParams | None) -> str:
if filter_params is None:
return "{}"
return repr(self._to_builtins(filter_params.filters))
def _serialize_query(self, query: QuerySpec) -> dict[str, Any]:
return {
"pagination": query.pagination.value,
"limit": query.limit,
"page": query.page,
"cursor": query.cursor,
"sort": [{"field": sort.field, "direction": sort.direction} for sort in query.sort],
"filters": self._serialize_filter_group(query.filters),
}
def _serialize_filter_group(self, group: FilterGroup | None) -> dict[str, Any] | None:
if group is None:
return None
return {
"op": group.op,
"filters": [
{
"field": item.field,
"op": item.op.value,
"value": self._to_builtins(item.value),
}
for item in group.filters
],
}
def _struct_keys(self, payload: msgspec.Struct) -> list[str]:
data = msgspec.to_builtins(payload)
if isinstance(data, dict):
return list(data.keys())
return []