from __future__ import annotations
from dataclasses import dataclass, replace
from datetime import date, datetime, time
from decimal import Decimal
from types import UnionType
from typing import Any, ClassVar, Union, cast, get_args, get_origin, get_type_hints
import msgspec
from loom.core.model.field import ColumnType, Field
from loom.core.model.projection import Projection
from loom.core.model.relation import Relation
from loom.core.model.types import JSON, Boolean, DateTime, Float, Integer, Numeric, String
[docs]
@dataclass(frozen=True, slots=True)
class ColumnFieldInfo:
"""Resolved metadata for a single column field."""
name: str
python_type: type
column_type: ColumnType
field: Field
def _collect_inherited_dict_metadata(cls: type, attr: str) -> dict[str, Any]:
    """Merge dict metadata from the full MRO (base -> subclass).

    Each class in ``cls.__mro__`` contributes only the dict stored in its *own*
    namespace under ``attr``. Classes are visited from the most generic to the
    most specific, so later (more specific) classes override earlier keys,
    mirroring normal attribute resolution. Values that are not dicts are ignored.

    Args:
        cls: Class whose MRO is inspected.
        attr: Name of the class attribute holding the metadata dict.

    Returns:
        A new dict with the merged entries.
    """
    merged: dict[str, Any] = {}
    for current in reversed(cls.__mro__):
        # Read the class's own namespace only. ``getattr`` would fall back to an
        # inherited dict, re-applying a base's entries after a more specific
        # sibling in the MRO had already overridden them.
        raw = vars(current).get(attr)
        if isinstance(raw, dict):
            merged.update(raw)
    return merged
[docs]
def get_column_fields(cls: type) -> dict[str, ColumnFieldInfo]:
    """Extract column fields from a model class.

    Every msgspec struct field of *cls* that is neither a relation, a projection,
    nor a ``ClassVar`` becomes a column. Column metadata is resolved in priority
    order:

    1. An entry registered under ``__loom_columns__`` (merged across the MRO).
       Its ``column_type`` is used when set; otherwise it is inferred.
    2. ``ColumnType`` / ``Field`` entries inside ``Annotated[T, ...]`` metadata.
       When a ``Field`` is given without a ``ColumnType``, the ``Field`` is kept
       and only the column type is inferred.
    3. Inference from the annotation alone, with a default ``Field()``.

    When the resolved ``Field`` has no explicit default, the struct field's own
    default (if any) is copied onto it.

    Args:
        cls: A msgspec struct model class.

    Returns:
        Mapping of field name to ``ColumnFieldInfo``, in struct field order.
    """
    declared_columns = _collect_inherited_dict_metadata(cls, "__loom_columns__")
    hints = get_type_hints(cls, include_extras=True)
    # Relations and projections are struct fields too, but they are not columns.
    non_columns = set(get_relations(cls)) | set(get_projections(cls))
    result: dict[str, ColumnFieldInfo] = {}
    for struct_field in msgspec.structs.fields(cls):
        name = struct_field.name
        if name in non_columns:
            continue
        annotation = hints.get(name, Any)
        if _is_classvar(annotation):
            continue
        declared = declared_columns.get(name)
        if declared is not None:
            field = declared.field
            # A falsy registered column type means "infer it" (truthiness check
            # kept from the registration contract).
            column_type = declared.column_type or None
        else:
            field, column_type = _split_column_metadata(annotation)
        if column_type is None:
            column_type = _infer_column_type(annotation, field=field)
        result[name] = ColumnFieldInfo(
            name=name,
            python_type=_extract_origin_type(annotation),
            column_type=column_type,
            field=_with_struct_default(field, struct_field.default),
        )
    return result


def _split_column_metadata(annotation: Any) -> tuple[Field, ColumnType | None]:
    """Return the last ``Field`` and last ``ColumnType`` in ``Annotated`` metadata.

    Falls back to ``Field()`` and ``None`` when the respective entry is absent.
    """
    field = Field()
    column_type: ColumnType | None = None
    for entry in _extract_metadata(annotation):
        if isinstance(entry, ColumnType):
            column_type = entry
        elif isinstance(entry, Field):
            field = entry
    return field, column_type
def _with_struct_default(field: Field, struct_default: Any) -> Field:
    """Copy the msgspec struct default onto *field* when it lacks one.

    *field* is returned unchanged when it already carries a default (anything
    other than ``msgspec.UNSET``) or when the struct field has no default
    (``msgspec.NODEFAULT``). Otherwise a copy with ``default`` replaced is returned.
    """
    keep_as_is = field.default is not msgspec.UNSET or struct_default is msgspec.NODEFAULT
    if keep_as_is:
        return field
    updated = replace(field, default=struct_default)
    return cast(Field, updated)  # type: ignore[redundant-cast]
[docs]
def get_relations(cls: type) -> dict[str, Relation]:
    """Return relations registered by ``LoomStructMeta``.

    Entries stored under ``__loom_relations__`` are merged across the MRO,
    with subclasses overriding their bases.
    """
    relations: dict[str, Relation] = _collect_inherited_dict_metadata(
        cls, "__loom_relations__"
    )
    return relations
[docs]
def get_projections(cls: type) -> dict[str, Projection]:
    """Return projections registered by ``LoomStructMeta``.

    Entries stored under ``__loom_projections__`` are merged across the MRO,
    with subclasses overriding their bases.
    """
    projections: dict[str, Projection] = _collect_inherited_dict_metadata(
        cls, "__loom_projections__"
    )
    return projections
[docs]
def get_id_attribute(cls: type) -> str:
    """Return the name of the primary key field.

    The first column (in struct field order) whose ``Field`` has
    ``primary_key`` set wins.

    Raises:
        ValueError: If no column is marked as primary key.
    """
    primary = next(
        (
            column_name
            for column_name, column in get_column_fields(cls).items()
            if column.field.primary_key
        ),
        None,
    )
    if primary is None:
        raise ValueError(f"No primary key field found on {cls.__name__}")
    return primary
[docs]
def get_table_name(cls: type) -> str:
    """Return the ``__tablename__`` declared on the model.

    The attribute is resolved with normal attribute lookup, so a value
    inherited from a base class is accepted.

    Raises:
        ValueError: If ``__tablename__`` is missing or is not a string.
    """
    table_name = getattr(cls, "__tablename__", None)
    if isinstance(table_name, str):
        return table_name
    raise ValueError(f"{cls.__name__} does not declare __tablename__")
def _extract_metadata(annotation: Any) -> tuple[Any, ...]:
    """Pull metadata entries from ``Annotated[T, ...]``.

    Returns the annotation's ``__metadata__`` tuple, or an empty tuple when the
    annotation carries none.
    """
    try:
        return annotation.__metadata__
    except AttributeError:
        return ()
def _extract_origin_type(annotation: Any) -> type[Any]:
"""Return the base type from ``Annotated[T, ...]``."""
origin = getattr(annotation, "__origin__", None)
if origin is not None:
args = getattr(annotation, "__args__", ())
if args:
value = args[0]
if isinstance(value, type):
return value
return object
raw = _unwrap_optional(annotation)
origin = get_origin(raw)
if origin is not None:
if isinstance(origin, type):
return origin
return object
if isinstance(raw, type):
return raw
return object
def _unwrap_optional(annotation: Any) -> Any:
origin = get_origin(annotation)
if origin in (UnionType, Union):
args = tuple(arg for arg in get_args(annotation) if arg is not type(None))
if len(args) == 1:
return args[0]
return annotation
def _is_classvar(annotation: Any) -> bool:
    """Return True when *annotation* is a ``ClassVar[...]`` annotation."""
    origin = get_origin(annotation)
    return origin is ClassVar
# Python scalar type -> column type, consulted by ``_infer_column_type`` after
# ``str``, ``date`` and ``time`` have been handled; unlisted types fall back to JSON.
# Lookup is a plain dict ``get`` keyed on the exact type object, so ``bool``
# resolves to ``Boolean`` rather than ``Integer`` even though it subclasses ``int``.
_SCALAR_TYPE_MAP: dict[type, ColumnType] = {
    int: Integer,
    float: Float,
    bool: Boolean,
    # NOTE(review): always built with tz=True -- confirm naive datetimes are not expected.
    datetime: DateTime(tz=True),
    Decimal: Numeric(),
}
def _infer_column_type(annotation: Any, *, field: Field) -> ColumnType:
    """Infer a column type from a Python annotation.

    ``Annotated[...]`` and ``Optional[...]`` wrappers are peeled first (in any
    nesting order). Then:

    * any remaining parameterized generic (``list[int]``, ``frozenset[str]``,
      multi-member unions, ...) maps to ``JSON``;
    * ``str`` maps to ``String(field.length)``;
    * ``date`` and ``time`` map to ``String(None)``;
    * other types are looked up in ``_SCALAR_TYPE_MAP``, defaulting to ``JSON``.

    Args:
        annotation: The resolved field annotation.
        field: Field options; only ``field.length`` is read here.
    """
    base = annotation
    # Peel Annotated[...] (wrapped type lives in ``__origin__``) and Optional[...]
    # until neither applies, so e.g. Annotated[int | None, ...] is seen as int.
    while True:
        if hasattr(base, "__metadata__"):
            base = base.__origin__
            continue
        unwrapped = _unwrap_optional(base)
        if unwrapped is base:
            break
        base = unwrapped
    # A still-parameterized annotation is a container or union, never a scalar;
    # looking at its element type would e.g. turn frozenset[int] into Integer.
    if get_origin(base) is not None:
        return JSON
    python_type = _extract_origin_type(base)
    if python_type is str:
        return String(field.length)
    if python_type in (date, time):
        return String(None)
    return _SCALAR_TYPE_MAP.get(python_type, JSON)