"""Polars target writer implementing _WritePolicy hooks."""
from __future__ import annotations
from collections.abc import Sequence
from typing import Any, Literal
import polars as pl
from deltalake import CommitProperties, DeltaTable, WriterProperties, write_deltalake
from loom.core.logger import get_logger
from loom.etl.backends._historify._transform import scd2_transform
from loom.etl.backends._merge import (
SOURCE_ALIAS,
TARGET_ALIAS,
_build_merge_plan,
_build_partition_predicate,
_log_partition_combos,
_warn_no_partition_cols,
)
from loom.etl.backends._predicate import predicate_to_sql
from loom.etl.backends._write_policy import _WritePolicy
from loom.etl.backends.polars._historify import PolarsHistorifyBackend
from loom.etl.backends.polars._schema import (
PolarsPhysicalSchema,
apply_schema_polars,
read_delta_physical_schema,
)
from loom.etl.declarative.target import SchemaMode
from loom.etl.declarative.target._file import FileSpec
from loom.etl.declarative.target._history import HistorifyRepairReport, HistorifySpec
from loom.etl.declarative.target._table import (
AppendSpec,
ReplacePartitionsSpec,
UpsertSpec,
)
from loom.etl.lineage._records import LineageRecord, WriteContext, get_lineage_schema
from loom.etl.schema._schema import SchemaError
from loom.etl.storage import (
MissingTablePolicy,
PathRouteResolver,
PathTarget,
ResolvedTarget,
TableLocation,
TableRouteResolver,
)
from loom.etl.storage._config import AuditConfig
from loom.etl.storage._file_locator import FileLocator
from loom.etl.storage._locator import TableLocator, _as_locator
from ._dtype import loom_type_to_polars
from ._file_writer import PolarsFileWriter
from ._streaming import partition_rows_from_spool, spool_lazy_to_arrow_reader
# Module-level logger; every structured write-start/commit/warning event in this file goes through it.
_log = get_logger(__name__)
def _log_write_start(
    mode: str,
    location: TableLocation,
    frame: pl.DataFrame,
    *,
    partition_cols: tuple[str, ...] = (),
    predicate: str | None = None,
) -> None:
    """Emit the structured ``delta write start`` event for an eager Delta write.

    Args:
        mode: Write-mode label recorded on the event (e.g. ``"append"``).
        location: Target table location; only its ``uri`` is logged.
        frame: Materialized frame about to be written; its row and column
            counts are logged.
        partition_cols: Partition columns, logged as a list, or ``None`` when
            the tuple is empty.
        predicate: Optional SQL predicate that scopes the write.
    """
    event_fields: dict[str, Any] = {
        "mode": mode,
        "uri": location.uri,
        "rows": frame.height,
        "cols": len(frame.columns),
        # An empty tuple is reported as None rather than [].
        "partition_cols": list(partition_cols) if partition_cols else None,
        "predicate": predicate,
    }
    _log.info("delta write start", **event_fields)
def _log_write_commit(mode: str, location: TableLocation) -> None:
try:
dt = DeltaTable(location.uri, storage_options=location.storage_options or {})
history = dt.history(1)
entry = history[0] if history else {}
metrics = entry.get("operationMetrics") or {}
_log.info(
"delta write commit",
mode=mode,
uri=location.uri,
version=dt.version(),
rows=int(metrics.get("numOutputRows", 0)) or None,
bytes=int(metrics.get("numOutputBytes", 0)) or None,
files=int(metrics.get("numAddedFiles", 0)) or None,
)
except Exception: # noqa: BLE001
_log.info("delta write commit", mode=mode, uri=location.uri, metrics="unavailable")
def _raise_on_null_dtype(schema: pl.Schema) -> None:
    """Raise :exc:`~loom.etl.schema.SchemaError` if any column in ``schema`` is ``Null``.

    Shared by the eager and lazy checks so both report an identical,
    actionable message that lists the offending columns and the full schema.

    Args:
        schema: Column-name to dtype mapping of the frame about to be written.

    Raises:
        SchemaError: When one or more columns carry ``pl.Null`` dtype.
    """
    null_cols = [name for name, dtype in schema.items() if isinstance(dtype, pl.Null)]
    if not null_cols:
        return
    schema_repr = "\n".join(f" {n}: {d}" for n, d in schema.items())
    raise SchemaError(
        f"Delta Lake does not support the Null dtype. "
        f"Column(s) with Null dtype: {null_cols}.\n"
        f"Cast each column to its intended type before writing "
        f"(use an explicit schema declaration or .cast() in the pipeline step).\n"
        f"Full frame schema:\n{schema_repr}"
    )


def _check_null_dtype_columns_lazy(frame: pl.LazyFrame) -> None:
    """Lazy variant of :func:`_check_null_dtype_columns`.

    Raises :exc:`~loom.etl.schema.SchemaError` if any column in the lazy
    schema has dtype ``Null``. Operates on ``collect_schema()`` (schema
    resolution only, no data collected), so it is safe to call on a frame
    about to be streamed.

    Args:
        frame: Lazy frame about to be written to Delta.

    Raises:
        SchemaError: When one or more columns carry ``pl.Null`` dtype.
    """
    _raise_on_null_dtype(frame.collect_schema())


def _check_null_dtype_columns(frame: pl.DataFrame) -> None:
    """Raise :exc:`~loom.etl.schema.SchemaError` if any column has dtype ``Null``.

    Delta Lake rejects ``Null``-typed columns. This check surfaces the
    offending column names before the write attempt, so the error is
    actionable rather than an opaque external exception.

    Args:
        frame: Collected DataFrame about to be written to Delta.

    Raises:
        SchemaError: When one or more columns carry ``pl.Null`` dtype.
    """
    _raise_on_null_dtype(frame.schema)
class PolarsTargetWriter(_WritePolicy[pl.LazyFrame, pl.DataFrame, PolarsPhysicalSchema]):
    """Polars target writer using delta-rs for Delta tables.

    Implements the ``_WritePolicy`` schema, write and audit hooks on top of
    ``deltalake.write_deltalake`` and ``DeltaTable.merge``. File targets (CSV,
    JSON, Parquet) are delegated to :class:`PolarsFileWriter`. Every Delta hook
    requires a path-resolved target (:class:`PathTarget`).
    """

    def __init__(
        self,
        locator: str | TableLocator,
        *,
        route_resolver: TableRouteResolver | None = None,
        missing_table_policy: MissingTablePolicy = MissingTablePolicy.SCHEMA_MODE,
        file_locator: FileLocator | None = None,
        audit_config: AuditConfig | None = None,
    ) -> None:
        """Create the writer.

        Args:
            locator: Table locator, or a value normalised by ``_as_locator``.
            route_resolver: Resolver for table refs; defaults to a
                ``PathRouteResolver`` over ``locator``.
            missing_table_policy: Forwarded to ``_WritePolicy``.
            file_locator: Resolves ``IntoFile.alias`` paths. Needed only for
                alias-based file targets.
            audit_config: Forwarded to ``_WritePolicy``.
        """
        self._locator = _as_locator(locator)
        super().__init__(
            resolver=route_resolver or PathRouteResolver(self._locator),
            missing_table_policy=missing_table_policy,
            audit_config=audit_config,
        )
        self._file_writer = PolarsFileWriter()
        self._file_locator = file_locator

    def append(
        self,
        frame: pl.LazyFrame,
        table_ref: Any,
        params_instance: Any,
        *,
        streaming: bool = False,
    ) -> None:
        """Append frame to table (legacy API, creates table on first write).

        Wraps ``table_ref`` in an ``AppendSpec`` with ``SchemaMode.EVOLVE`` and
        routes it through :meth:`write`.
        """
        spec = AppendSpec(table_ref=table_ref, schema_mode=SchemaMode.EVOLVE)
        self.write(frame, spec, params_instance, streaming=streaming)

    def to_frame(self, records: Sequence[LineageRecord], /) -> pl.LazyFrame:
        """Convert lineage records into a Polars LazyFrame.

        Raises:
            ValueError: When ``records`` is empty.
            TypeError: When the records are not all of the same type. The
                schema is derived from the type of the first record.
        """
        if not records:
            raise ValueError("PolarsTargetWriter.to_frame requires at least one record.")
        first = records[0]
        record_type = type(first)
        if any(type(record) is not record_type for record in records):
            raise TypeError(
                "PolarsTargetWriter.to_frame requires homogeneous record types per batch."
            )
        rows = [record.to_row() for record in records]
        return pl.from_dicts(rows, schema=_polars_lineage_schema(first)).lazy()

    # ====================================================================
    # Schema Hooks
    # ====================================================================

    def _physical_schema(self, target: ResolvedTarget) -> PolarsPhysicalSchema | None:
        """Read physical schema from Delta log (``None`` per ``read_delta_physical_schema``)."""
        path_target = self._as_path_target(target)
        return read_delta_physical_schema(
            path_target.location.uri,
            path_target.location.storage_options,
        )

    def _align(
        self,
        frame: pl.LazyFrame,
        existing_schema: PolarsPhysicalSchema | None,
        mode: SchemaMode,
    ) -> pl.LazyFrame:
        """Align frame schema with the existing table schema (if any) per ``mode``."""
        schema = existing_schema.schema if existing_schema is not None else None
        return apply_schema_polars(frame, schema, mode)

    def _materialize_for_write(self, frame: pl.LazyFrame, streaming: bool) -> pl.DataFrame:
        """Collect the lazy frame before a Delta write and reject ``Null`` columns."""
        df = frame.collect(engine="streaming" if streaming else "auto")
        _check_null_dtype_columns(df)
        return df

    def _do_replace_partitions(
        self,
        frame: pl.LazyFrame,
        target: ResolvedTarget,
        spec: ReplacePartitionsSpec,
        streaming: bool,
    ) -> None:
        """Replace partitions, using the spooled streaming path when possible.

        The spooled path is taken only when ``spec.streaming`` is set AND the
        target table already exists (its schema is needed for alignment).
        Otherwise the base-class implementation handles the write. The
        ``streaming`` argument is forwarded to the base class; it does not
        select the spooled path.
        """
        if not spec.streaming:
            super()._do_replace_partitions(frame, target, spec, streaming)
            return
        existing = self._physical_schema(target)
        if existing is None:
            super()._do_replace_partitions(frame, target, spec, streaming)
            return
        aligned = self._align(frame, existing, spec.schema_mode)
        _check_null_dtype_columns_lazy(aligned)
        self._streaming_replace_partitions(aligned, target, spec)

    def _streaming_replace_partitions(
        self,
        frame: pl.LazyFrame,
        target: ResolvedTarget,
        spec: ReplacePartitionsSpec,
    ) -> None:
        """Replace partitions by spooling ``frame`` and streaming the spool to Delta.

        The lazy frame is spooled once. Distinct partition combinations are
        read back from the spool to build the overwrite predicate. The spool's
        Arrow reader is then passed to ``write_deltalake`` instead of a
        collected DataFrame.

        Raises:
            ValueError: When any partition combination contains a null value.
        """
        path_target = self._as_path_target(target)
        location = path_target.location
        with spool_lazy_to_arrow_reader(frame) as spool:
            partition_rows = partition_rows_from_spool(spool.spool_path, spec.partition_cols)
            if not partition_rows:
                _log.warning("replace_partitions empty frame", table=location.uri)
                return
            self._check_null_partition_rows(partition_rows, spec.partition_cols, path_target)
            predicate = _build_partition_predicate(iter(partition_rows), spec.partition_cols)
            _log.info(
                "delta write start",
                mode="replace_partitions",
                uri=location.uri,
                rows="streaming",
                cols=None,
                partition_cols=list(spec.partition_cols),
                predicate=predicate,
            )
            # The reader is backed by the spool, so the write must happen
            # while the context manager is still open.
            write_deltalake(
                location.uri,
                spool.reader,
                mode="overwrite",
                predicate=predicate,
                schema_mode=self._delta_schema_mode(spec.schema_mode),
                **self._write_kwargs(location),
            )
        _log_write_commit("replace_partitions", location)

    def _predicate_to_sql(self, predicate: Any, params: Any) -> str:
        """Convert predicate to SQL (Polars uses internal SQL representation)."""
        return predicate_to_sql(predicate, params)

    # ====================================================================
    # Write Hooks
    # ====================================================================

    def _create(
        self,
        frame: pl.DataFrame,
        target: ResolvedTarget,
        *,
        schema_mode: SchemaMode,
        partition_cols: tuple[str, ...] = (),
    ) -> None:
        """Create a new Delta table from ``frame``.

        A new table always takes the frame's schema (``schema_mode="overwrite"``),
        so the requested ``schema_mode`` is not used here.

        Raises:
            ValueError: When a partition column is missing from a non-empty frame.
        """
        path_target = self._as_path_target(target)
        _ = schema_mode  # part of the hook signature; a fresh table has no schema to honour
        location = path_target.location
        if partition_cols:
            missing = tuple(c for c in partition_cols if c not in frame.columns)
            if missing:
                # Empty frame with no partition cols is a legitimate no-op
                # (e.g. decomposition fallback when the source field is absent
                # from the input schema). Mirrors _replace_partitions on the
                # update path.
                if frame.is_empty():
                    _log.warning(
                        "create skipped: empty frame missing partition columns",
                        table=location.uri,
                        missing=list(missing),
                    )
                    return
                raise ValueError(
                    f"_create: partition columns not found in frame schema: "
                    f"{list(missing)}. Frame columns: {frame.columns}. Cannot "
                    f"create partitioned Delta table at {location.uri!r}."
                )
        self._warn_uc_first_create(path_target)
        kwargs = self._write_kwargs(location)
        _log_write_start("create", location, frame, partition_cols=partition_cols)
        write_deltalake(
            location.uri,
            frame,
            mode="overwrite",
            # None (the write_deltalake default) means an unpartitioned table.
            partition_by=list(partition_cols) or None,
            schema_mode="overwrite",
            **kwargs,
        )
        _log_write_commit("create", location)

    @staticmethod
    def _delta_schema_mode(
        schema_mode: SchemaMode,
    ) -> Literal["merge", "overwrite"] | None:
        """Map Loom SchemaMode to the delta-rs schema_mode string.

        Returns None for STRICT so delta-rs applies its default behaviour
        (reject writes that change the schema).
        """
        if schema_mode is SchemaMode.OVERWRITE:
            return "overwrite"
        if schema_mode is SchemaMode.EVOLVE:
            return "merge"
        return None  # STRICT — let delta-rs reject schema changes

    def _append(
        self,
        frame: pl.DataFrame,
        target: ResolvedTarget,
        *,
        schema_mode: SchemaMode,
    ) -> None:
        """Append ``frame`` to an existing Delta table."""
        path_target = self._as_path_target(target)
        location = path_target.location
        _log_write_start("append", location, frame)
        write_deltalake(
            location.uri,
            frame,
            mode="append",
            schema_mode=self._delta_schema_mode(schema_mode),
            **self._write_kwargs(location),
        )
        _log_write_commit("append", location)

    def _replace(
        self,
        frame: pl.DataFrame,
        target: ResolvedTarget,
        *,
        schema_mode: SchemaMode,
    ) -> None:
        """Overwrite an existing Delta table with ``frame``."""
        path_target = self._as_path_target(target)
        location = path_target.location
        _log_write_start("replace", location, frame)
        write_deltalake(
            location.uri,
            frame,
            mode="overwrite",
            schema_mode=self._delta_schema_mode(schema_mode),
            **self._write_kwargs(location),
        )
        _log_write_commit("replace", location)

    def _replace_partitions(
        self,
        frame: pl.DataFrame,
        target: ResolvedTarget,
        *,
        partition_cols: tuple[str, ...],
        schema_mode: SchemaMode,
    ) -> None:
        """Overwrite exactly the partitions present in ``frame``.

        An empty frame is a logged no-op. Otherwise the distinct partition
        combinations become an overwrite predicate, so other partitions are
        left untouched.

        Raises:
            ValueError: When any partition combination contains a null value.
        """
        path_target = self._as_path_target(target)
        if frame.is_empty():
            _log.warning(
                "replace_partitions empty frame",
                table=path_target.location.uri,
            )
            return
        location = path_target.location
        partition_rows: list[dict[str, Any]] = list(
            frame.select(list(partition_cols)).unique().iter_rows(named=True)
        )
        self._check_null_partition_rows(partition_rows, partition_cols, path_target)
        predicate = _build_partition_predicate(
            iter(partition_rows),
            partition_cols,
        )
        _log_write_start(
            "replace_partitions",
            location,
            frame,
            partition_cols=partition_cols,
            predicate=predicate,
        )
        write_deltalake(
            location.uri,
            frame,
            mode="overwrite",
            predicate=predicate,
            schema_mode=self._delta_schema_mode(schema_mode),
            **self._write_kwargs(location),
        )
        _log_write_commit("replace_partitions", location)

    def _replace_where(
        self,
        frame: pl.DataFrame,
        target: ResolvedTarget,
        *,
        predicate: str,
        schema_mode: SchemaMode,
    ) -> None:
        """Overwrite the rows matching the SQL ``predicate`` with ``frame``."""
        path_target = self._as_path_target(target)
        location = path_target.location
        _log_write_start("replace_where", location, frame, predicate=predicate)
        write_deltalake(
            location.uri,
            frame,
            mode="overwrite",
            predicate=predicate,
            schema_mode=self._delta_schema_mode(schema_mode),
            **self._write_kwargs(location),
        )
        _log_write_commit("replace_where", location)

    def _upsert(
        self,
        frame: pl.DataFrame,
        target: ResolvedTarget,
        *,
        spec: UpsertSpec,
        existing_schema: PolarsPhysicalSchema,
    ) -> None:
        """Merge ``frame`` into the target using Delta MERGE.

        Distinct partition combinations of ``frame`` (when
        ``spec.partition_cols`` is set) are passed to ``_build_merge_plan`` as
        a pre-filter. Matched rows are updated and unmatched rows inserted.
        The table's configured writer and commit properties are applied, as
        they are for every other write mode.
        """
        path_target = self._as_path_target(target)
        _ = existing_schema  # part of the hook signature; not needed for MERGE
        location = path_target.location
        # Collect partition combinations for pre-filter
        combos = self._collect_partition_combos(
            frame,
            spec.partition_cols,
            path_target.logical_ref.ref,
        )
        merge_plan = _build_merge_plan(
            combos=combos,
            spec=spec,
            df_columns=tuple(frame.columns),
            target_alias=TARGET_ALIAS,
            source_alias=SOURCE_ALIAS,
        )
        write_kwargs = self._write_kwargs(location)
        dt = DeltaTable(location.uri, storage_options=location.storage_options or {})
        _log_write_start(
            "upsert",
            location,
            frame,
            partition_cols=spec.partition_cols,
            predicate=merge_plan.predicate,
        )
        (
            dt.merge(
                source=frame,
                predicate=merge_plan.predicate,
                source_alias=SOURCE_ALIAS,
                target_alias=TARGET_ALIAS,
                writer_properties=write_kwargs["writer_properties"],
                commit_properties=write_kwargs["commit_properties"],
            )
            .when_matched_update(updates=merge_plan.update_set)
            .when_not_matched_insert(updates=merge_plan.insert_values)
            .execute()
        )
        _log_write_commit("upsert", location)

    def _read_existing_data(
        self,
        target: ResolvedTarget,
        frame: pl.LazyFrame,
        spec: HistorifySpec,
    ) -> pl.DataFrame | None:
        """Read existing Delta data pruned to the partitions present in ``frame``.

        When ``spec.partition_scope`` is set, distinct partition-column values
        are extracted from ``frame`` (still lazy, one cheap ``collect``). They
        are applied as a filter on a lazy Delta scan. Without a partition
        scope, the full table is read.

        Args:
            target: Resolved path target.
            frame: Incoming LazyFrame; only partition columns are touched here.
            spec: Historify spec; ``partition_scope`` carries column names.

        Returns:
            Collected DataFrame of existing rows, or ``None`` if the table has
            not been created yet.
        """
        from deltalake.exceptions import TableNotFoundError

        path_target = self._as_path_target(target)
        location = path_target.location
        storage_options = location.storage_options or None
        try:
            if not spec.partition_scope:
                return pl.read_delta(location.uri, storage_options=storage_options)
            partition_vals = frame.select(list(spec.partition_scope)).unique().collect()
            scan = pl.scan_delta(location.uri, storage_options=storage_options)
            filter_expr = _partition_filter(partition_vals, spec.partition_scope)
            return scan.filter(filter_expr).collect()
        except (TableNotFoundError, FileNotFoundError):
            return None

    def _historify(
        self,
        frame: pl.DataFrame,
        existing: pl.DataFrame | None,
        target: ResolvedTarget,
        *,
        spec: HistorifySpec,
        params_instance: Any,
    ) -> HistorifyRepairReport | None:
        """Run the SCD Type 2 transform and write the result via the write hooks.

        The write path depends on the situation:

        * No existing table: create it.
        * Partition scope set: replace only the affected partitions.
        * Otherwise: replace the whole table.

        Always returns ``None`` (no repair report).
        """
        result = scd2_transform(PolarsHistorifyBackend(), frame, existing, spec, params_instance)
        if existing is None:
            self._create(
                result,
                target,
                schema_mode=spec.schema_mode,
                partition_cols=spec.partition_scope or (),
            )
        elif spec.partition_scope:
            self._replace_partitions(
                result, target, partition_cols=spec.partition_scope, schema_mode=spec.schema_mode
            )
        else:
            self._replace(result, target, schema_mode=spec.schema_mode)
        return None

    def _write_file(
        self,
        frame: pl.LazyFrame,
        spec: FileSpec,
        *,
        streaming: bool,
    ) -> None:
        """Write to file (CSV, JSON, Parquet), resolving alias if needed."""
        resolved = self._resolve_file_spec(spec)
        self._file_writer.write(frame, resolved, streaming=streaming)

    def _resolve_file_spec(self, spec: FileSpec) -> FileSpec:
        """Return a FileSpec with a physical URI, resolving alias when required.

        Raises:
            ValueError: When ``spec`` is an alias but no file locator is configured.
        """
        if not spec.is_alias:
            return spec
        if self._file_locator is None:
            raise ValueError(
                f"IntoFile.alias({spec.path!r}) requires storage.files to be configured. "
                "Set storage.files in your config YAML."
            )
        location = self._file_locator.locate(spec.path)
        return FileSpec(
            path=location.uri_template,
            format=spec.format,
            is_alias=False,
            write_options=spec.write_options,
        )

    # ====================================================================
    # Audit Hook
    # ====================================================================

    def _apply_audit_columns(
        self,
        frame: pl.LazyFrame,
        write_ctx: WriteContext | None,
        params_instance: Any,
        audit: AuditConfig,
    ) -> pl.LazyFrame:
        """Inject audit columns into a Polars LazyFrame.

        Skipped when ``audit.enabled`` is ``False`` or ``write_ctx`` is ``None``.
        ``process_run_id`` is only added when it is not ``None``, to avoid
        nullable-column schema drift in Delta tables.

        Args:
            frame: Incoming lazy frame.
            write_ctx: Step execution context.
            params_instance: Concrete params; used to resolve ``from_param``.
            audit: Audit configuration from ``StorageConfig``.

        Returns:
            Frame with audit columns appended, or the original frame unchanged.
        """
        if not audit.enabled or write_ctx is None:
            return frame
        prefix = audit.prefix
        cols: list[pl.Expr] = [
            pl.lit(write_ctx.run_id).alias(f"{prefix}run_id"),
            pl.lit(write_ctx.step).alias(f"{prefix}step"),
            pl.lit(write_ctx.attempt).cast(pl.Int32).alias(f"{prefix}attempt"),
        ]
        if write_ctx.process_run_id is not None:
            cols.append(pl.lit(write_ctx.process_run_id).alias(f"{prefix}process_run_id"))
        for col_def in audit.custom:
            cols.append(_resolve_custom_column(col_def, params_instance))
        return frame.with_columns(cols)

    # ====================================================================
    # Helpers
    # ====================================================================

    @staticmethod
    def _check_null_partition_rows(
        partition_rows: Sequence[dict[str, Any]],
        partition_cols: tuple[str, ...],
        path_target: PathTarget,
    ) -> None:
        """Reject partition combinations containing nulls.

        A null value cannot be expressed in the equality predicate that scopes
        a partition overwrite. Such rows are therefore rejected before writing.
        At most five offending rows are quoted, so the message stays bounded
        for large frames.

        Raises:
            ValueError: When any row has a ``None`` value.
        """
        null_rows = [r for r in partition_rows if any(v is None for v in r.values())]
        if not null_rows:
            return
        sample = null_rows[:5]
        more = (
            f" (+{len(null_rows) - len(sample)} more)"
            if len(null_rows) > len(sample)
            else ""
        )
        raise ValueError(
            f"replace_partitions: null values in partition columns "
            f"{partition_cols} for table {path_target}. Null partition "
            f"values cannot be used as a replace predicate. Ensure all "
            f"partition columns are non-null before writing. "
            f"Offending rows: {sample}{more}"
        )

    @staticmethod
    def _warn_uc_first_create(target: PathTarget) -> None:
        """Warn about UC table creation limitations (only for ``uc://`` URIs)."""
        if not target.location.uri.lower().startswith("uc://"):
            return
        _log.warning(
            "Polars write first-create for Unity Catalog. "
            "delta-rs writes Delta log data, but catalog registration is not guaranteed; "
            "pre-create the table in UC (Spark SQL/Databricks) before this write.",
            ref=target.logical_ref.ref,
            uri=target.location.uri,
        )

    @staticmethod
    def _as_path_target(target: ResolvedTarget) -> PathTarget:
        """Validate and narrow resolved target to path target for Polars writes.

        Raises:
            TypeError: When ``target`` is not a :class:`PathTarget`.
        """
        if isinstance(target, PathTarget):
            return target
        raise TypeError(
            "PolarsTargetWriter requires path-resolved targets. "
            "Configure routing so catalog refs resolve to PathTarget "
            "(e.g. uc://catalog.schema.table)."
        )

    @staticmethod
    def _write_kwargs(loc: TableLocation) -> dict[str, Any]:
        """Build ``write_deltalake`` kwargs from a TableLocation.

        Empty settings are mapped to ``None`` so delta-rs applies its defaults.
        """
        return {
            "storage_options": loc.storage_options or None,
            "configuration": loc.delta_config or None,
            "writer_properties": WriterProperties(**loc.writer) if loc.writer else None,
            "target_file_size": loc.target_file_size,
            "commit_properties": CommitProperties(**loc.commit) if loc.commit else None,
        }

    @staticmethod
    def _collect_partition_combos(
        frame: pl.DataFrame,
        partition_cols: tuple[str, ...],
        table_ref: str,
    ) -> list[dict[str, Any]]:
        """Collect unique partition value combinations (empty list when unpartitioned)."""
        if not partition_cols:
            _warn_no_partition_cols(table_ref)
            return []
        # Narrow to the partition columns first so `unique` only hashes those columns.
        combos = frame.select(list(partition_cols)).unique().to_dicts()
        _log_partition_combos(combos, table_ref)
        return combos
# Public surface of this module: only the writer class is exported.
__all__: list[str] = ["PolarsTargetWriter"]
def _resolve_custom_column(col_def: Any, params_instance: Any) -> pl.Expr:
"""Build a Polars literal expression for one custom audit column definition.
Args:
col_def: :class:`~loom.etl.storage._config.CustomColumnDef` instance.
params_instance: Concrete params object; used when ``from_param`` is set.
Returns:
A ``pl.lit(value).alias(name)`` expression ready for ``with_columns``.
Raises:
AttributeError: When ``from_param`` is set but the attribute is missing
from ``params_instance``.
"""
if col_def.value is not None:
return pl.lit(str(col_def.value)).alias(col_def.name)
val = str(getattr(params_instance, col_def.from_param))
return pl.lit(val).alias(col_def.name)
def _partition_filter(
    partition_vals: pl.DataFrame,
    partition_scope: tuple[str, ...],
) -> pl.Expr:
    """Build a Polars filter expression that matches exact partition combinations.

    A loose column-wise ``is_in`` would form a bounding hyper-rectangle and may
    read unwanted partition combinations. Instead, this builds an ``OR`` of
    ``AND`` equalities, one per distinct combination in ``partition_vals``,
    which the Delta scan can use to skip files for other combinations.

    A ``None`` partition value is matched with ``is_null()``. ``col == None``
    evaluates to null in Polars, so the row would be filtered out.

    Args:
        partition_vals: DataFrame with one column per partition key and one row
            per distinct combination present in the incoming frame.
        partition_scope: Ordered column names that form the partition key.

    Returns:
        A Polars expression suitable for use in ``.filter()``. When there are no
        combinations, ``pl.lit(False)`` (matches nothing).
    """

    def _matches(col: str, value: Any) -> pl.Expr:
        # Null-safe equality for a single partition column.
        return pl.col(col).is_null() if value is None else pl.col(col) == value

    combos = partition_vals.select(list(partition_scope)).unique().iter_rows(named=True)
    exprs: list[pl.Expr] = []
    for combo in combos:
        and_expr = _matches(partition_scope[0], combo[partition_scope[0]])
        for col in partition_scope[1:]:
            and_expr = and_expr & _matches(col, combo[col])
        exprs.append(and_expr)
    if not exprs:
        return pl.lit(False)
    result = exprs[0]
    for expr in exprs[1:]:
        result = result | expr
    return result
def _polars_lineage_schema(record: LineageRecord) -> pl.Schema:
    """Build the Polars schema for a lineage record's type.

    Column names and Loom dtypes come from ``get_lineage_schema`` for the
    record's class. Each dtype is translated with ``loom_type_to_polars``.
    """
    mapping = {}
    for column in get_lineage_schema(type(record)):
        mapping[column.name] = loom_type_to_polars(column.dtype)
    return pl.Schema(mapping)