"""Spark source reader - direct implementation without layers."""
from __future__ import annotations
import logging
from collections.abc import Callable
from typing import Any
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import functions as F
from loom.etl.backends._format_registry import resolve_format_handler
from loom.etl.backends._predicate import predicate_to_sql
from loom.etl.backends.spark._dtype import loom_type_to_spark
from loom.etl.declarative._format import Format
from loom.etl.declarative._read_options import CsvReadOptions, JsonReadOptions
from loom.etl.declarative.source import FileSourceSpec, SourceSpec, TableSourceSpec
from loom.etl.runtime.contracts import SourceReader
from loom.etl.schema._schema import ColumnSchema
from loom.etl.storage._file_locator import FileLocator
from loom.etl.storage._locator import TableLocator, _as_locator
from loom.etl.storage.routing import (
CatalogRouteResolver,
CatalogTarget,
PathRouteResolver,
TableRouteResolver,
)
_log = logging.getLogger(__name__)
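
# Illustrative usage (a sketch, not a guaranteed contract; assumes a live
# SparkSession ``spark``, a SourceSpec ``spec`` built by the declarative layer,
# and the pipeline's ``params`` object referenced by any predicates):
#
#     reader = SparkSourceReader(spark)  # catalog routing by default
#     frame = reader.read(spec, params_instance=params)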
class SparkSourceReader(SourceReader):
"""Spark source reader - reads Delta tables and files directly."""
def __init__(
self,
spark: SparkSession,
locator: str | TableLocator | None = None,
*,
route_resolver: TableRouteResolver | None = None,
file_locator: FileLocator | None = None,
) -> None:
self._spark = spark
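        # Routing precedence: an explicit route_resolver wins, then an explicit
        # locator (path routing), and finally the catalog-based default.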
if route_resolver is None:
if locator is None:
resolver: TableRouteResolver = CatalogRouteResolver()
else:
resolver = PathRouteResolver(_as_locator(locator))
else:
resolver = route_resolver
self._resolver = resolver
self._file_locator = file_locator
def read(self, spec: SourceSpec, params_instance: Any) -> DataFrame:
"""Read source spec and return DataFrame."""
if isinstance(spec, TableSourceSpec):
return self._read_table(spec, params_instance)
if isinstance(spec, FileSourceSpec):
return self._read_file(spec)
raise TypeError(
f"SparkSourceReader does not support source kind {spec.kind!r}. "
"TEMP sources are handled by CheckpointStore."
)
def execute_sql(self, frames: dict[str, Any], query: str) -> DataFrame:
"""Execute SQL query against backend frames."""
return execute_sql(frames, query)
def _read_table(self, spec: TableSourceSpec, params: Any) -> DataFrame:
"""Read Delta table."""
target = self._resolver.resolve(spec.table_ref)
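        # Catalog targets resolve through the metastore; anything else is loaded
        # as Delta files from the resolved storage location.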
if isinstance(target, CatalogTarget):
df = self._spark.table(target.catalog_ref.ref)
else:
df = self._spark.read.format("delta").load(target.location.uri)
# Apply predicates
for pred in spec.predicates:
df = df.filter(predicate_to_sql(pred, params))
return self._finalize_source_frame(df, spec)
def _read_file(self, spec: FileSourceSpec) -> DataFrame:
"""Read file (CSV, JSON, Parquet), resolving alias if needed."""
path = self._resolve_file_path(spec)
df = self._read_file_by_format(path, spec.format, spec.read_options)
return self._finalize_source_frame(df, spec)
def _resolve_file_path(self, spec: FileSourceSpec) -> str:
"""Return the physical URI for *spec*, resolving alias when required."""
if not spec.is_alias:
return spec.path
if self._file_locator is None:
raise ValueError(
f"FromFile.alias({spec.path!r}) requires storage.files to be configured. "
"Set storage.files in your config YAML."
)
return self._file_locator.locate(spec.path).uri_template
def _finalize_source_frame(
self,
df: DataFrame,
spec: TableSourceSpec | FileSourceSpec,
) -> DataFrame:
"""Apply post-read transformations: columns, schema, json decode."""
if spec.columns:
df = df.select(list(spec.columns))
if spec.schema:
df = self._apply_source_schema(df, spec.schema)
if spec.json_columns:
df = self._apply_json_decode(df, spec.json_columns)
return df
def _read_file_by_format(self, path: str, format: Any, options: Any) -> DataFrame:
"""Dispatch to format-specific reader."""
readers: dict[Format, Callable[[str, Any], DataFrame]] = {
Format.DELTA: self._read_delta_file,
Format.CSV: self._read_csv_file,
Format.JSON: self._read_json_file,
Format.PARQUET: self._read_parquet_file,
Format.XLSX: self._read_xlsx_file,
}
reader = resolve_format_handler(format, readers)
return reader(path, options)
def _read_delta_file(self, path: str, _options: Any) -> DataFrame:
return self._spark.read.format("delta").load(path)
def _read_csv_file(self, path: str, options: Any) -> DataFrame:
csv_opts = options if isinstance(options, CsvReadOptions) else CsvReadOptions()
if csv_opts.skip_rows:
raise ValueError("Spark backend does not support skip_rows for CSV files.")
reader = (
self._spark.read.option("sep", csv_opts.separator)
.option("header", str(csv_opts.has_header).lower())
.option("encoding", csv_opts.encoding)
.option("inferSchema", "true")
)
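        # Spark's CSV reader accepts a single null sentinel; use the first declared value.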
if csv_opts.null_values:
reader = reader.option("nullValue", csv_opts.null_values[0])
if csv_opts.infer_schema_length is None:
reader = reader.option("samplingRatio", "1.0")
return reader.csv(path)
def _read_json_file(self, path: str, options: Any) -> DataFrame:
json_opts = options if isinstance(options, JsonReadOptions) else JsonReadOptions()
reader = self._spark.read.option("inferSchema", "true")
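        # As with CSV, an unset inference limit means full-file schema sampling.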
if json_opts.infer_schema_length is None:
reader = reader.option("samplingRatio", "1.0")
return reader.json(path)
def _read_parquet_file(self, path: str, _options: Any) -> DataFrame:
return self._spark.read.parquet(path)
def _read_xlsx_file(self, path: str, _options: Any) -> DataFrame:
_ = path
raise TypeError("Spark backend does not support XLSX format.")
def _apply_source_schema(self, df: DataFrame, schema: tuple[ColumnSchema, ...]) -> DataFrame:
"""Cast declared columns to their LoomDtype equivalents."""
if not schema:
return df
for col in schema:
df = df.withColumn(col.name, F.col(col.name).cast(loom_type_to_spark(col.dtype)))
return df
def _apply_json_decode(self, df: DataFrame, json_columns: tuple[Any, ...]) -> DataFrame:
"""Decode JSON string columns."""
if not json_columns:
return df
for jc in json_columns:
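            # loom_type_to_spark(...).simpleString() yields a DDL-style type string
            # (e.g. "struct<...>") that from_json accepts as the target schema.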
schema_ddl = loom_type_to_spark(jc.loom_type).simpleString()
df = df.withColumn(jc.column, F.from_json(F.col(jc.column), schema_ddl))
return df
def execute_sql(frames: dict[str, DataFrame], query: str) -> DataFrame:
"""Execute SQL query against temporary views created from Spark frames."""
first = next(iter(frames.values()), None)
if first is None:
raise ValueError("StepSQL requires at least one source frame.")
isolated = first.sparkSession.newSession()
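    # A DataFrame is tied to the session that created it, so rebuild each frame in
    # the isolated session before registering it as a temp view there.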
for name, frame in frames.items():
isolated.createDataFrame(frame.rdd, frame.schema).createOrReplaceTempView(name)
return isolated.sql(query)
__all__ = ["SparkSourceReader", "execute_sql"]