# Source code for loom.etl.testing.spark

"""Spark testing utilities and pytest fixtures for the Loom ETL framework.

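Provides :class:`SparkTestSession` — a context manager that handles Spark
configuration, Delta extension setup, and teardown — plus
:class:`SparkStepRunner`, an in-memory harness that runs a single ETL step
against seeded tables, and the opt-in pytest fixtures ``loom_spark_session``
and ``loom_spark_runner``.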

Import :mod:`loom.etl.testing.spark` explicitly from Spark-enabled test
projects, or register it via ``pytest_plugins`` / ``pytest -p`` in the
consuming suite. Spark-only dependencies (``pyspark``, ``delta``, ``polars``)
are imported lazily inside the functions that use them rather than at module
import time, so non-Spark installations can remain lightweight.

Usage in ``conftest.py``::

    import pytest
    from loom.etl.testing.spark import SparkTestSession, SparkStepRunner

    @pytest.fixture(scope="session")
    def spark():
        with SparkTestSession.start(app="my-tests", parallelism=1) as session:
            yield session

    @pytest.fixture
    def step_runner(spark):
        return SparkStepRunner(spark)
"""

from __future__ import annotations

import os
import re
import tempfile
from collections.abc import Generator
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from types import TracebackType
from typing import TYPE_CHECKING, Any

import pytest

from loom.etl.compiler import ETLCompiler
from loom.etl.declarative.source import SourceSpec
from loom.etl.declarative.target import TargetSpec
from loom.etl.executor import ETLExecutor
from loom.etl.testing._result import StepResult

if TYPE_CHECKING:
    import polars as pl
    from pyspark.sql import SparkSession


class SparkTestSession:
    """Own a local PySpark session wired for Delta Lake and stop it on exit.

    Build instances through :meth:`start`. The object is a context manager:
    ``__enter__`` hands back the raw session and ``__exit__`` always calls
    ``stop()`` on it, even when the body raises::

        with SparkTestSession.start(app="etl-tests") as spark:
            df = spark.createDataFrame([(1, "a")], ["id", "label"])

    Args:
        session: The :class:`pyspark.sql.SparkSession` to wrap.
    """

    def __init__(self, session: SparkSession) -> None:
        self._session = session

    @classmethod
    def start(
        cls,
        *,
        app: str = "loom-etl-test",
        parallelism: int = 1,
        memory: str = "1g",
        ivy_dir: str | Path | None = None,
    ) -> SparkTestSession:
        """Build (via ``getOrCreate``) a local, Delta-enabled SparkSession.

        The master is pinned to ``local[1]``. Delta jars are supplied either
        straight from the local jar cache — only when the sandbox reports the
        network as disabled and every required jar was found — or through
        ``delta.configure_spark_with_delta_pip`` otherwise.

        Args:
            app: Spark application name shown in the UI.
            parallelism: Value for ``spark.sql.shuffle.partitions``; keep it
                low (1–2) in unit tests to limit overhead.
            memory: Driver heap size string (e.g. ``"1g"``).
            ivy_dir: Optional Ivy cache directory for Maven package
                resolution. When omitted in a constrained sandbox, a writable
                temp directory is chosen automatically.

        Returns:
            A :class:`SparkTestSession` wrapping the active session.
        """
        from pyspark.sql import SparkSession

        delta_settings = (
            ("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension"),
            (
                "spark.sql.catalog.spark_catalog",
                "org.apache.spark.sql.delta.catalog.DeltaCatalog",
            ),
            ("spark.driver.memory", memory),
            ("spark.sql.shuffle.partitions", str(parallelism)),
        )
        builder = SparkSession.builder.master("local[1]").appName(app)
        for key, value in delta_settings:
            builder = builder.config(key, value)

        ivy_home = _resolve_ivy_dir(ivy_dir)
        if ivy_home is not None:
            # Create the Ivy directory up front before pointing Spark at it.
            ivy_home.mkdir(parents=True, exist_ok=True)
            builder = builder.config("spark.jars.ivy", str(ivy_home))

        cached_jars = _resolve_local_delta_jars()
        offline = _sandbox_network_disabled()
        if offline and cached_jars is not None:
            # No network: hand Spark the cached jars directly.
            jar_list = ",".join(map(str, cached_jars))
            spark = builder.config("spark.jars", jar_list).getOrCreate()
        else:
            from delta import configure_spark_with_delta_pip

            spark = configure_spark_with_delta_pip(builder).getOrCreate()

        spark.sparkContext.setLogLevel("ERROR")
        return cls(spark)

    def __enter__(self) -> SparkSession:
        return self._session

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # Unconditional stop; returning None lets any exception propagate.
        self._session.stop()

    @property
    def session(self) -> SparkSession:
        """The wrapped :class:`pyspark.sql.SparkSession`."""
        return self._session
# ---------------------------------------------------------------------------
# SparkStepRunner — in-memory Spark step test harness
# ---------------------------------------------------------------------------


class _SparkCapturingWriter:
    """Writer stand-in that remembers the last frame and target spec it got."""

    def __init__(self) -> None:
        self.frame: Any = None
        self.spec: TargetSpec | None = None

    def write(
        self,
        frame: Any,
        spec: TargetSpec,
        _params_instance: Any,
        *,
        streaming: bool = False,
        write_ctx: Any = None,
    ) -> None:
        """Record *frame* and *spec*; nothing is written anywhere."""
        # Accepted so the signature matches the caller's expectations; unused.
        del streaming, write_ctx
        self.frame, self.spec = frame, spec


class _SparkStubReader:
    """Reader stand-in that serves pre-built frames instead of doing I/O."""

    def __init__(self, frames: dict[str, Any]) -> None:
        self._frames = frames

    def read(
        self,
        spec: SourceSpec,
        _params_instance: Any,
        /,
    ) -> Any:
        """Return the seeded frame for *spec*.

        The lookup key is ``spec.table_ref.ref`` when the spec has a non-None
        ``table_ref`` attribute, otherwise ``spec.alias``. A missing key
        raises :class:`KeyError`.
        """
        table_ref = getattr(spec, "table_ref", None)
        if table_ref is None:
            return self._frames[spec.alias]
        return self._frames[table_ref.ref]

    def execute_sql(self, frames: dict[str, Any], query: str, /) -> Any:
        """Run *query* with each frame registered as a temp view by its name.

        Views are registered in a fresh ``newSession()`` derived from the
        first frame's session, so they do not leak into the shared session.

        Raises:
            ValueError: When *frames* is empty (or its first value is None).
        """
        sources = list(frames.values())
        anchor = sources[0] if sources else None
        if anchor is None:
            raise ValueError("StepSQL requires at least one source frame.")
        scratch = anchor.sparkSession.newSession()
        for view_name, df in frames.items():
            # Rebuild each frame inside the scratch session before naming it.
            rebuilt = scratch.createDataFrame(df.rdd, df.schema)
            rebuilt.createOrReplaceTempView(view_name)
        return scratch.sql(query)
class SparkStepRunner:
    """In-memory test harness for Spark :class:`~loom.etl.ETLStep` subclasses.

    Seeds are plain Python tuples — no Spark dependency at definition time.
    On every :meth:`run`, each seed is converted to a
    ``pyspark.sql.DataFrame`` so ``execute()`` receives the same type as in
    production. No Delta I/O — reads are served from the seeds and the
    step's write is captured in memory.

    Example::

        def test_aggregate(loom_spark_runner):
            loom_spark_runner.seed("raw.orders", [(1, 10.0)], ["id", "amount"])
            result = loom_spark_runner.run(AggregateStep, NoParams())
            result.assert_count(1)
            result.show()
    """

    def __init__(self, spark: SparkSession) -> None:
        self._spark = spark
        # Logical table ref -> (row tuples, column names). Kept as plain
        # Python data and turned into Spark DataFrames on every run().
        self._seeds: dict[str, tuple[list[tuple[Any, ...]], list[str]]] = {}
        # Captures the step's output; replaced at the start of every run().
        self._writer = _SparkCapturingWriter()

    def seed(
        self,
        ref: str,
        data: list[tuple[Any, ...]],
        columns: list[str],
    ) -> SparkStepRunner:
        """Register raw data under the logical table reference *ref*.

        Seeding the same *ref* again replaces the earlier data. *data* and
        *columns* are copied, so later mutation by the caller has no effect.

        NOTE(review): an empty *data* list is handed to ``createDataFrame``
        with column names only; PySpark typically cannot infer a schema from
        that — confirm before relying on empty seeds.

        Args:
            ref: Logical table reference, e.g. ``"raw.orders"``.
            data: Row data as a list of tuples.
            columns: Column names aligned with the tuple positions.

        Returns:
            ``self`` for fluent chaining.
        """
        self._seeds[ref] = (list(data), list(columns))
        return self

    def run(self, step_cls: type[Any], params: Any) -> StepResult:
        """Compile and execute *step_cls* against the seeded tables.

        Args:
            step_cls: :class:`~loom.etl.ETLStep` subclass to execute.
            params: Concrete params instance for this run.

        Returns:
            :class:`~loom.etl.testing._result.StepResult` for assertions.

        Raises:
            KeyError: When a source table was not seeded.
            RuntimeError: When the step produced no output.
        """
        # Reset capture state before anything can fail, so a failure while
        # building frames or compiling never leaves target_spec reporting
        # the spec from a previous run.
        self._writer = _SparkCapturingWriter()
        frames = self._build_frames()
        plan = ETLCompiler().compile_step(step_cls)
        ETLExecutor(_SparkStubReader(frames), self._writer).run_step(plan, params)
        raw = self._writer.frame
        if raw is None:
            raise RuntimeError("Step produced no output — check that target is declared.")
        return StepResult(_spark_frame_to_polars(raw))

    def _build_frames(self) -> dict[str, Any]:
        """Materialise every seed as a Spark DataFrame keyed by its ref."""
        return {
            ref: self._spark.createDataFrame(data, columns)
            for ref, (data, columns) in self._seeds.items()
        }

    @property
    def target_spec(self) -> TargetSpec:
        """Target spec captured during the last :meth:`run` call.

        Raises:
            RuntimeError: When :meth:`run` has not been called yet, or the
                last run did not reach the write.
        """
        if self._writer.spec is None:
            raise RuntimeError("No spec — call run() first.")
        return self._writer.spec
# ---------------------------------------------------------------------------
# pytest fixtures — opt-in via explicit import or pytest plugin registration
# ---------------------------------------------------------------------------


@pytest.fixture(scope="session")
def loom_spark_session() -> Generator[SparkSession, None, None]:
    """Session-scoped local SparkSession with Delta Lake extensions enabled.

    Created once per pytest session to amortise the ~5 s JVM startup cost;
    the session is stopped when pytest tears the fixture down.
    """
    spark_test_session = SparkTestSession.start(parallelism=1, app="loom-etl-tests")
    with spark_test_session as spark:
        yield spark
@pytest.fixture
def loom_spark_runner(loom_spark_session: SparkSession) -> SparkStepRunner:
    """Per-test :class:`SparkStepRunner` bound to the shared Spark session.

    Each test gets a fresh runner (empty seeds, nothing captured) while the
    session-scoped ``loom_spark_session`` is reused to avoid repeated JVM
    startups. No Delta I/O; everything stays in memory.
    """
    runner = SparkStepRunner(loom_spark_session)
    return runner
def _resolve_ivy_dir(ivy_dir: str | Path | None) -> Path | None:
    """Pick the Ivy user directory to hand to Spark as ``spark.jars.ivy``.

    Precedence: explicit *ivy_dir*, then the ``LOOM_SPARK_IVY_DIR``
    environment variable, then ``<tempdir>/loom-spark-ivy`` when
    ``CODEX_SANDBOX`` is ``"seatbelt"``. Returns ``None`` when none applies,
    leaving Spark's own default in place.
    """
    if ivy_dir is not None:
        return Path(ivy_dir)
    from_env = os.getenv("LOOM_SPARK_IVY_DIR")
    if from_env:
        return Path(from_env)
    if os.getenv("CODEX_SANDBOX") == "seatbelt":
        return Path(tempfile.gettempdir()) / "loom-spark-ivy"
    return None


def _sandbox_network_disabled() -> bool:
    """Return True when ``CODEX_SANDBOX_NETWORK_DISABLED`` is set to ``"1"``."""
    return os.getenv("CODEX_SANDBOX_NETWORK_DISABLED") == "1"


def _resolve_local_delta_jars() -> tuple[Path, ...] | None:
    """Locate cached Delta + ANTLR jars for an offline session start.

    Searches ``<ivy dir>/jars`` for the Ivy directory derived from the
    environment (``_resolve_ivy_dir(None)``) first — Spark retrieves
    packages into the ``jars`` folder of the configured Ivy directory — and
    then the default ``~/.ivy2/jars``. An ``ivy_dir`` passed explicitly to
    ``SparkTestSession.start`` is not visible to this function.

    Returns:
        ``(delta-spark, delta-storage, antlr4-runtime)`` jar paths from the
        first directory holding all three, or ``None`` if none does.
    """
    delta_version = _delta_spark_version()
    for jar_dir in _candidate_jar_dirs():
        jars = _delta_jars_in(jar_dir, delta_version)
        if jars is not None:
            return jars
    return None


def _candidate_jar_dirs() -> list[Path]:
    """Return existing jar directories to search, most specific first."""
    candidates: list[Path] = []
    ivy_dir = _resolve_ivy_dir(None)
    if ivy_dir is not None:
        candidates.append(ivy_dir / "jars")
    default_dir = Path.home() / ".ivy2" / "jars"
    if default_dir not in candidates:
        candidates.append(default_dir)
    return [jar_dir for jar_dir in candidates if jar_dir.is_dir()]


def _delta_jars_in(jar_dir: Path, delta_version: str | None) -> tuple[Path, ...] | None:
    """Return the three required jars from *jar_dir*, or None if any is missing."""
    delta_spark = _pick_jar(jar_dir, "io.delta_delta-spark_2.12-", delta_version)
    delta_storage = _pick_jar(jar_dir, "io.delta_delta-storage-", delta_version)
    antlr = _pick_jar(jar_dir, "org.antlr_antlr4-runtime-", None)
    if delta_spark is None or delta_storage is None or antlr is None:
        return None
    return (delta_spark, delta_storage, antlr)


def _delta_spark_version() -> str | None:
    """Installed ``delta-spark`` distribution version, or None if absent."""
    try:
        return version("delta-spark")
    except PackageNotFoundError:
        return None


# Text following a jar prefix: a dotted numeric release plus an optional
# qualifier, e.g. "3.2.0" or "4.0.0rc1".
_JAR_VERSION_RE = re.compile(r"(\d+(?:\.\d+)*)(.*)")


def _jar_version_key(version_text: str) -> tuple[tuple[int, ...], bool, str] | None:
    """Sort key for a jar version string, or None if it is not version-shaped.

    Numeric components compare as integers (so ``3.10.0`` > ``3.9.0``), and
    a plain release ranks above a qualified one with the same numbers
    (``4.0.0`` > ``4.0.0rc1``).
    """
    match = _JAR_VERSION_RE.fullmatch(version_text)
    if match is None:
        return None
    numbers = tuple(int(part) for part in match.group(1).split("."))
    qualifier = match.group(2)
    return (numbers, not qualifier, qualifier)


def _pick_jar(jar_dir: Path, prefix: str, expected_version: str | None) -> Path | None:
    """Choose a ``<prefix><version>.jar`` file from *jar_dir*.

    An exact *expected_version* match wins; otherwise the highest version is
    returned. Files whose text after *prefix* is not version-shaped are
    ignored — e.g. ``io.delta_delta-storage-s3-dynamodb-*.jar`` is not a
    candidate for the ``io.delta_delta-storage-`` prefix.
    """
    if expected_version is not None:
        expected = jar_dir / f"{prefix}{expected_version}.jar"
        if expected.exists():
            return expected
    ranked: list[tuple[tuple[tuple[int, ...], bool, str], Path]] = []
    for jar in jar_dir.glob(f"{prefix}*.jar"):
        key = _jar_version_key(jar.name[len(prefix) : -len(".jar")])
        if key is not None:
            ranked.append((key, jar))
    if not ranked:
        return None
    # Tie-break on the file name so the choice is deterministic.
    return max(ranked, key=lambda item: (item[0], item[1].name))[1]


def _spark_frame_to_polars(frame: Any) -> pl.DataFrame:
    """Collect a Spark DataFrame to the driver and return it as Polars.

    The Polars frame is built column-wise, so every value in a column takes
    part in dtype inference (row-wise construction only scans the first
    ``infer_schema_length`` rows), and an empty result keeps its column
    names.
    """
    import polars as pl

    columns = list(frame.columns)
    rows = frame.collect()
    return pl.DataFrame(
        {col: [row[index] for row in rows] for index, col in enumerate(columns)}
    )