Source code for loom.etl.backends.polars._writer

"""Polars target writer implementing _WritePolicy hooks."""

from __future__ import annotations

import logging
from collections.abc import Sequence
from typing import Any

import polars as pl
from deltalake import CommitProperties, DeltaTable, WriterProperties, write_deltalake

from loom.etl.backends._historify._transform import scd2_transform
from loom.etl.backends._merge import (
    SOURCE_ALIAS,
    TARGET_ALIAS,
    _build_merge_plan,
    _build_partition_predicate,
    _log_partition_combos,
    _warn_no_partition_cols,
)
from loom.etl.backends._predicate import predicate_to_sql
from loom.etl.backends._write_policy import _WritePolicy
from loom.etl.backends.polars._historify import PolarsHistorifyBackend
from loom.etl.backends.polars._schema import (
    PolarsPhysicalSchema,
    apply_schema_polars,
    read_delta_physical_schema,
)
from loom.etl.declarative.target import SchemaMode
from loom.etl.declarative.target._file import FileSpec
from loom.etl.declarative.target._history import HistorifyRepairReport, HistorifySpec
from loom.etl.declarative.target._table import AppendSpec, UpsertSpec
from loom.etl.observability.records import ExecutionRecord, get_record_schema
from loom.etl.storage import (
    MissingTablePolicy,
    PathRouteResolver,
    PathTarget,
    ResolvedTarget,
    TableLocation,
    TableRouteResolver,
)
from loom.etl.storage._file_locator import FileLocator
from loom.etl.storage._locator import TableLocator, _as_locator

from ._dtype import loom_type_to_polars
from ._file_writer import PolarsFileWriter

_log = logging.getLogger(__name__)


class PolarsTargetWriter(_WritePolicy[pl.LazyFrame, pl.DataFrame, PolarsPhysicalSchema]):
    """Polars target writer using delta-rs for Delta tables."""

    def __init__(
        self,
        locator: str | TableLocator,
        *,
        route_resolver: TableRouteResolver | None = None,
        missing_table_policy: MissingTablePolicy = MissingTablePolicy.SCHEMA_MODE,
        file_locator: FileLocator | None = None,
    ) -> None:
        self._locator = _as_locator(locator)
        super().__init__(
            resolver=route_resolver or PathRouteResolver(self._locator),
            missing_table_policy=missing_table_policy,
        )
        self._file_writer = PolarsFileWriter()
        self._file_locator = file_locator
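
    # Illustrative construction sketch (the locator string and table ref below
    # are assumptions, not fixtures from this package). Routing defaults to a
    # PathRouteResolver over the given locator when no resolver is passed:
    #
    #     writer = PolarsTargetWriter("abfss://lake@acct.dfs.core.windows.net/warehouse")
    #     writer.append(df.lazy(), table_ref="sales.orders", params_instance=params)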

    def append(
        self,
        frame: pl.LazyFrame,
        table_ref: Any,
        params_instance: Any,
        *,
        streaming: bool = False,
    ) -> None:
        """Append frame to table (legacy API, creates table on first write)."""
        spec = AppendSpec(table_ref=table_ref, schema_mode=SchemaMode.EVOLVE)
        self.write(frame, spec, params_instance, streaming=streaming)

    def to_frame(self, records: Sequence[ExecutionRecord], /) -> pl.LazyFrame:
        """Convert execution records into a Polars LazyFrame."""
        if not records:
            raise ValueError("PolarsTargetWriter.to_frame requires at least one record.")
        first = records[0]
        record_type = type(first)
        if any(type(record) is not record_type for record in records):
            raise TypeError(
                "PolarsTargetWriter.to_frame requires homogeneous record types per batch."
            )
        rows = [record.to_row() for record in records]
        return pl.from_dicts(rows, schema=_polars_record_schema(first)).lazy()

    # ====================================================================
    # Schema Hooks
    # ====================================================================

    def _physical_schema(self, target: ResolvedTarget) -> PolarsPhysicalSchema | None:
        """Read physical schema from Delta log."""
        path_target = self._as_path_target(target)
        physical = read_delta_physical_schema(
            path_target.location.uri,
            path_target.location.storage_options,
        )
        return physical

    def _align(
        self,
        frame: pl.LazyFrame,
        existing_schema: PolarsPhysicalSchema | None,
        mode: SchemaMode,
    ) -> pl.LazyFrame:
        """Align frame schema with existing."""
        schema = existing_schema.schema if existing_schema is not None else None
        return apply_schema_polars(frame, schema, mode)

    def _materialize_for_write(self, frame: pl.LazyFrame, streaming: bool) -> pl.DataFrame:
        """Collect lazy frame to DataFrame before Delta write."""
        return frame.collect(engine="streaming" if streaming else "auto")

    def _predicate_to_sql(self, predicate: Any, params: Any) -> str:
        """Convert predicate to SQL (Polars uses internal SQL representation)."""
        return predicate_to_sql(predicate, params)

    # ====================================================================
    # Write Hooks
    # ====================================================================

    def _create(
        self,
        frame: pl.DataFrame,
        target: ResolvedTarget,
        *,
        schema_mode: SchemaMode,
        partition_cols: tuple[str, ...] = (),
    ) -> None:
        """Create new Delta table."""
        path_target = self._as_path_target(target)
        _ = schema_mode
        location = path_target.location
        self._warn_uc_first_create(path_target)
        kwargs = self._write_kwargs(location)
        if partition_cols:
            write_deltalake(
                location.uri,
                frame,
                mode="overwrite",
                partition_by=list(partition_cols),
                **kwargs,
            )
        else:
            write_deltalake(location.uri, frame, mode="overwrite", **kwargs)

    def _append(
        self,
        frame: pl.DataFrame,
        target: ResolvedTarget,
        *,
        schema_mode: SchemaMode,
    ) -> None:
        """Append to existing Delta table."""
        path_target = self._as_path_target(target)
        _ = schema_mode
        location = path_target.location
        write_deltalake(location.uri, frame, mode="append", **self._write_kwargs(location))

    def _replace(
        self,
        frame: pl.DataFrame,
        target: ResolvedTarget,
        *,
        schema_mode: SchemaMode,
    ) -> None:
        """Overwrite existing Delta table."""
        path_target = self._as_path_target(target)
        _ = schema_mode
        location = path_target.location
        write_deltalake(location.uri, frame, mode="overwrite", **self._write_kwargs(location))

    def _replace_partitions(
        self,
        frame: pl.DataFrame,
        target: ResolvedTarget,
        *,
        partition_cols: tuple[str, ...],
        schema_mode: SchemaMode,
    ) -> None:
        """Overwrite partitions present in frame."""
        path_target = self._as_path_target(target)
        _ = schema_mode
        if frame.is_empty():
            _log.warning("replace_partitions table=%s has 0 rows — nothing written", path_target)
            return
        location = path_target.location
        predicate = _build_partition_predicate(
            frame.select(list(partition_cols)).unique().iter_rows(named=True),
            partition_cols,
        )
        write_deltalake(
            location.uri,
            frame,
            mode="overwrite",
            predicate=predicate,
            **self._write_kwargs(location),
        )

    def _replace_where(
        self,
        frame: pl.DataFrame,
        target: ResolvedTarget,
        *,
        predicate: str,
        schema_mode: SchemaMode,
    ) -> None:
        """Overwrite rows matching SQL predicate."""
        path_target = self._as_path_target(target)
        _ = schema_mode
        location = path_target.location
        write_deltalake(
            location.uri,
            frame,
            mode="overwrite",
            predicate=predicate,
            **self._write_kwargs(location),
        )

    def _upsert(
        self,
        frame: pl.DataFrame,
        target: ResolvedTarget,
        *,
        spec: UpsertSpec,
        existing_schema: PolarsPhysicalSchema,
    ) -> None:
        """Merge frame into target using Delta MERGE."""
        path_target = self._as_path_target(target)
        _ = existing_schema
        location = path_target.location
        # Collect partition combinations for pre-filter
        combos = self._collect_partition_combos(
            frame,
            spec.partition_cols,
            path_target.logical_ref.ref,
        )
        merge_plan = _build_merge_plan(
            combos=combos,
            spec=spec,
            df_columns=tuple(frame.columns),
            target_alias=TARGET_ALIAS,
            source_alias=SOURCE_ALIAS,
        )
        dt = DeltaTable(location.uri, storage_options=location.storage_options or {})
        (
            dt.merge(
                source=frame,
                predicate=merge_plan.predicate,
                source_alias=SOURCE_ALIAS,
                target_alias=TARGET_ALIAS,
            )
            .when_matched_update(updates=merge_plan.update_set)
            .when_not_matched_insert(updates=merge_plan.insert_values)
            .execute()
        )

    def _read_existing_data(
        self,
        target: ResolvedTarget,
        frame: pl.LazyFrame,
        spec: HistorifySpec,
    ) -> pl.DataFrame | None:
        """Read existing Delta data pruned to the partitions present in ``frame``.

        When ``spec.partition_scope`` is set, distinct partition-column values are
        extracted from ``frame`` (still lazy, one cheap ``collect``) and pushed down
        as a filter before the Delta scan — so only the affected Parquet files are
        opened. Without a partition scope the full table is returned.

        Args:
            target: Resolved path target.
            frame: Incoming LazyFrame; only partition columns are touched here.
            spec: Historify spec; ``partition_scope`` carries column names.

        Returns:
            Collected DataFrame of existing rows, or ``None`` if the table has not
            been created yet.
        """
        from deltalake.exceptions import TableNotFoundError

        path_target = self._as_path_target(target)
        location = path_target.location
        storage_options = location.storage_options or {}
        try:
            if not spec.partition_scope:
                return pl.read_delta(location.uri, storage_options=storage_options or None)
            partition_vals = frame.select(list(spec.partition_scope)).unique().collect()
            scan = pl.scan_delta(location.uri, storage_options=storage_options or None)
            filter_expr = _partition_filter(partition_vals, spec.partition_scope)
            return scan.filter(filter_expr).collect()
        except (TableNotFoundError, FileNotFoundError):
            return None

    def _historify(
        self,
        frame: pl.DataFrame,
        existing: pl.DataFrame | None,
        target: ResolvedTarget,
        *,
        spec: HistorifySpec,
        params_instance: Any,
    ) -> HistorifyRepairReport | None:
        """Run SCD Type 2 transform and write result via existing write hooks."""
        result = scd2_transform(PolarsHistorifyBackend(), frame, existing, spec, params_instance)
        if existing is None:
            self._create(
                result,
                target,
                schema_mode=spec.schema_mode,
                partition_cols=spec.partition_scope or (),
            )
        elif spec.partition_scope:
            self._replace_partitions(
                result, target, partition_cols=spec.partition_scope, schema_mode=spec.schema_mode
            )
        else:
            self._replace(result, target, schema_mode=spec.schema_mode)
        return None

    def _write_file(
        self,
        frame: pl.LazyFrame,
        spec: FileSpec,
        *,
        streaming: bool,
    ) -> None:
        """Write to file (CSV, JSON, Parquet), resolving alias if needed."""
        resolved = self._resolve_file_spec(spec)
        self._file_writer.write(frame, resolved, streaming=streaming)

    def _resolve_file_spec(self, spec: FileSpec) -> FileSpec:
        """Return a FileSpec with a physical URI, resolving alias when required."""
        if not spec.is_alias:
            return spec
        if self._file_locator is None:
            raise ValueError(
                f"IntoFile.alias({spec.path!r}) requires storage.files to be configured. "
                "Set storage.files in your config YAML."
            )
        location = self._file_locator.locate(spec.path)
        return FileSpec(
            path=location.uri_template,
            format=spec.format,
            is_alias=False,
            write_options=spec.write_options,
        )

    # ====================================================================
    # Helpers
    # ====================================================================

    @staticmethod
    def _warn_uc_first_create(target: PathTarget) -> None:
        """Warn about UC table creation limitations."""
        if not target.location.uri.lower().startswith("uc://"):
            return
        _log.warning(
            "Polars write first-create for Unity Catalog ref=%s uri=%s. "
            "delta-rs writes Delta log data, but catalog registration is not guaranteed; "
            "pre-create the table in UC (Spark SQL/Databricks) before this write.",
            target.logical_ref.ref,
            target.location.uri,
        )

    @staticmethod
    def _as_path_target(target: ResolvedTarget) -> PathTarget:
        """Validate and narrow resolved target to path target for Polars writes."""
        if isinstance(target, PathTarget):
            return target
        raise TypeError(
            "PolarsTargetWriter requires path-resolved targets. "
            "Configure routing so catalog refs resolve to PathTarget "
            "(e.g. uc://catalog.schema.table)."
        )

    @staticmethod
    def _write_kwargs(loc: TableLocation) -> dict[str, Any]:
        """Build write kwargs from TableLocation."""
        return {
            "storage_options": loc.storage_options or None,
            "configuration": loc.delta_config or None,
            "writer_properties": WriterProperties(**loc.writer) if loc.writer else None,
            "commit_properties": CommitProperties(**loc.commit) if loc.commit else None,
        }

    @staticmethod
    def _collect_partition_combos(
        frame: pl.DataFrame,
        partition_cols: tuple[str, ...],
        table_ref: str,
    ) -> list[dict[str, Any]]:
        """Collect unique partition value combinations."""
        if not partition_cols:
            _warn_no_partition_cols(table_ref)
            return []
        combos = frame.unique(subset=list(partition_cols)).select(list(partition_cols)).to_dicts()
        _log_partition_combos(combos, table_ref)
        return combos
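

# Illustrative shape of the partition pre-filter data used by ``_upsert``
# (example column names and values are assumptions, not fixtures from this
# package): ``_collect_partition_combos`` returns one dict per distinct
# partition-key combination, e.g.
#
#     [{"region": "eu", "day": 1}, {"region": "us", "day": 2}]
#
# which is passed to ``_build_merge_plan`` when the MERGE predicate is built.
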
__all__ = ["PolarsTargetWriter"]


def _partition_filter(
    partition_vals: pl.DataFrame,
    partition_scope: tuple[str, ...],
) -> pl.Expr:
    """Build a Polars filter expression that matches exact partition combinations.

    Instead of a loose column-wise ``is_in`` (which forms a bounding
    hyper-rectangle and may read unwanted partition combinations), this builds
    an ``OR`` of ``AND`` equalities — one per distinct partition combo present
    in ``partition_vals``. Delta-rs pushes these predicates to the file-scan
    level, opening only the Parquet files that belong to the exact combinations
    needed.

    Args:
        partition_vals: DataFrame with one column per partition key and one row
            per distinct combination present in the incoming frame.
        partition_scope: Ordered column names that form the partition key.

    Returns:
        A Polars expression suitable for use in ``.filter()``.
    """
    combos = partition_vals.select(list(partition_scope)).unique().iter_rows(named=True)
    exprs: list[pl.Expr] = []
    for combo in combos:
        and_expr: pl.Expr = pl.col(partition_scope[0]) == combo[partition_scope[0]]
        for col in partition_scope[1:]:
            and_expr = and_expr & (pl.col(col) == combo[col])
        exprs.append(and_expr)
    if not exprs:
        return pl.lit(False)
    result = exprs[0]
    for expr in exprs[1:]:
        result = result | expr
    return result


def _polars_record_schema(record: ExecutionRecord) -> pl.Schema:
    cols = get_record_schema(type(record))
    return pl.Schema({c.name: loom_type_to_polars(c.dtype) for c in cols})
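

# Illustrative expansion of ``_partition_filter`` (assumed partition columns
# "region" and "day"; not part of the module): for two distinct combos the
# returned expression is an OR of per-combo AND equalities, e.g.
#
#     ((pl.col("region") == "eu") & (pl.col("day") == 1))
#     | ((pl.col("region") == "us") & (pl.col("day") == 2))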