Source code for loom.streaming.testing

"""Streaming testing helpers built on top of Bytewax testing sinks and sources.

Public API
----------
* :class:`StreamingTestRunner` — test-oriented runner with injectable input,
  captured output, and explicit error branch capture.

Typical usage::

    runner = StreamingTestRunner.from_flow(flow, config=cfg)
    runner.with_payloads([OrderPlaced(order_id="o-1")])
    runner.capture_errors(ErrorKind.WIRE)
    runner.run()
    assert len(runner.output) == 1
"""

from __future__ import annotations

import logging
from collections.abc import Mapping
from datetime import timedelta
from time import perf_counter
from typing import Any

import bytewax.testing as bytewax_testing
from bytewax.outputs import Sink

from loom.core.config import ConfigContext
from loom.core.observability.event import LifecycleEvent, LifecycleStatus, Scope
from loom.core.observability.runtime import ObservabilityRuntime
from loom.core.tracing import generate_trace_id
from loom.streaming import Message, MessageMeta
from loom.streaming.bytewax.runner import _prepare_run
from loom.streaming.compiler import CompiledPlan, compile_flow
from loom.streaming.compiler._plan import CompiledMultiSource, CompiledSingleSource
from loom.streaming.core._errors import ErrorKind
from loom.streaming.graph._flow import StreamFlow

logger = logging.getLogger(__name__)


class StreamingTestRunner:
    """Run a streaming flow with test doubles for input and output.

    The runner mirrors the production runner's dataflow preparation and swaps
    only the source and sinks for Bytewax testing primitives. Use
    :meth:`with_payloads` for ergonomic flow-author tests and
    :meth:`with_messages` when metadata must be controlled explicitly.

    Args:
        plan: Compiled plan produced by the compiler.
        observability_runtime: Optional runtime that receives lifecycle
            events for each run. Defaults to ``ObservabilityRuntime.noop()``.
    """

    def __init__(
        self,
        plan: CompiledPlan,
        observability_runtime: ObservabilityRuntime | None = None,
    ) -> None:
        self._plan = plan
        # Explicit ``None`` check so a caller-supplied runtime is never replaced
        # merely because it happens to evaluate as falsy.
        self._observability_runtime = (
            ObservabilityRuntime.noop() if observability_runtime is None else observability_runtime
        )
        # Items handed to the testing source on the next ``run()``.
        self._input: list[Any] = []
        # Shared buffer written by the main sink and by every terminal sink.
        self._output: list[Any] = []
        # One capture buffer per error kind registered via ``capture_errors()``.
        self._errors: dict[ErrorKind, list[Any]] = {}

    @classmethod
    def from_flow(
        cls,
        flow: StreamFlow[Any, Any],
        *,
        config: ConfigContext | Mapping[str, Any],
        observability_runtime: ObservabilityRuntime | None = None,
    ) -> StreamingTestRunner:
        """Compile a flow and build a test runner from resolved config.

        Args:
            flow: Flow definition to compile.
            config: A ready :class:`ConfigContext` or a plain mapping that is
                converted with ``ConfigContext.from_dict``.
            observability_runtime: Optional lifecycle-event runtime.
        """
        plan = compile_flow(flow, config=_ensure_config_context(config))
        return cls(plan, observability_runtime=observability_runtime)

    @classmethod
    def from_yaml(
        cls,
        flow: StreamFlow[Any, Any],
        path: str,
        *,
        observability_runtime: ObservabilityRuntime | None = None,
    ) -> StreamingTestRunner:
        """Load YAML config, compile the flow, and build a test runner.

        Args:
            flow: Flow definition to compile.
            path: Path of the YAML config file passed to ``ConfigContext.from_yaml``.
            observability_runtime: Optional lifecycle-event runtime.
        """
        return cls.from_flow(
            flow,
            config=ConfigContext.from_yaml(path),
            observability_runtime=observability_runtime,
        )

    @classmethod
    def from_dict(
        cls,
        flow: StreamFlow[Any, Any],
        config: dict[str, Any],
        *,
        observability_runtime: ObservabilityRuntime | None = None,
    ) -> StreamingTestRunner:
        """Build a test runner from a plain Python config mapping.

        Args:
            flow: Flow definition to compile.
            config: Raw config passed to ``ConfigContext.from_dict``.
            observability_runtime: Optional lifecycle-event runtime.
        """
        return cls.from_flow(
            flow,
            config=ConfigContext.from_dict(config),
            observability_runtime=observability_runtime,
        )

    def with_payloads(self, items: list[Any]) -> StreamingTestRunner:
        """Replace input with payload-derived test messages.

        Payloads are wrapped into :class:`loom.streaming.Message` values using
        deterministic test metadata:

        - ``message_id``: ``test-<index>``
        - ``topic``: first source topic declared by the flow, or ``None`` when
          the compiled source exposes no topic

        Returns:
            ``self``, for call chaining.
        """
        topic = _source_topic_or_none(self._plan)
        self._input = [_test_message(topic, idx, item) for idx, item in enumerate(items)]
        return self

    def with_messages(self, items: list[Any]) -> StreamingTestRunner:
        """Replace input with fully formed runtime messages or raw test records.

        The list is copied, so later mutation of ``items`` does not affect the
        runner. Returns ``self`` for call chaining.
        """
        self._input = list(items)
        return self

    def capture_errors(self, *kinds: ErrorKind) -> StreamingTestRunner:
        """Enable capture for explicit error branches.

        Registering a kind that is already captured keeps its existing buffer.
        Returns ``self`` for call chaining.
        """
        for kind in kinds:
            self._errors.setdefault(kind, [])
        return self

    def reset(self) -> StreamingTestRunner:
        """Clear captured input, output, and error buffers.

        Note that this also forgets which error kinds were registered; call
        :meth:`capture_errors` again before the next run if needed. The
        ``output`` list object is cleared in place, so existing references to
        it stay valid. Returns ``self`` for call chaining.
        """
        self._input.clear()
        self._output.clear()
        self._errors.clear()
        return self

    @property
    def output(self) -> list[Any]:
        """Captured items written to the main testing sink."""
        return self._output

    @property
    def errors(self) -> dict[ErrorKind, list[Any]]:
        """Captured items written to configured error sinks."""
        return self._errors

    def run(self) -> None:
        """Execute the compiled dataflow with testing source and sinks.

        Output is appended to :attr:`output` and is not cleared between runs;
        use :meth:`reset` for that. The prepared dataflow is always shut down,
        even when execution or observability emission raises.
        """
        error_sinks: dict[ErrorKind, Sink[Any]] = {
            kind: bytewax_testing.TestingSink(items) for kind, items in self._errors.items()
        }
        # Every terminal branch funnels into the same shared output buffer.
        terminal_sinks: dict[tuple[int, ...], Sink[Any]] = {
            path: bytewax_testing.TestingSink(self._output) for path in self._plan.terminal_sinks
        }
        # The main sink is only attached when the plan declares an output or
        # has no terminal sinks at all (otherwise the terminals cover it).
        main_sink: Sink[Any] | None = (
            bytewax_testing.TestingSink(self._output)
            if self._plan.output is not None or not self._plan.terminal_sinks
            else None
        )
        prepared = _prepare_run(
            plan=self._plan,
            observability_runtime=self._observability_runtime,
            # Snapshot the input so the source is unaffected by later edits.
            source=bytewax_testing.TestingSource(list(self._input)),
            sink=main_sink,
            terminal_sinks=terminal_sinks,
            error_sinks=error_sinks,
        )
        try:
            self._run_dataflow(prepared.dataflow)
        finally:
            # Shut down even if the start event emission itself raised.
            logger.debug("shutting_down_test_runner")
            prepared.shutdown()

    def _run_dataflow(self, dataflow: Any) -> None:
        """Run ``dataflow`` to completion, bracketed by start/end lifecycle events."""
        run_id = generate_trace_id()
        self._observability_runtime.emit(
            LifecycleEvent.start(
                scope=Scope.POLL_CYCLE,
                name=self._plan.name,
                id=run_id,
                meta={"node_count": len(self._plan.nodes)},
            )
        )
        started_at = perf_counter()
        # Pessimistic default: only flipped to SUCCESS once run_main returns.
        status = LifecycleStatus.FAILURE
        try:
            bytewax_testing.run_main(  # type: ignore[no-untyped-call]
                dataflow,
                epoch_interval=timedelta(milliseconds=1),
            )
            status = LifecycleStatus.SUCCESS
        finally:
            duration_ms = int((perf_counter() - started_at) * 1000)
            self._observability_runtime.emit(
                LifecycleEvent.end(
                    scope=Scope.POLL_CYCLE,
                    name=self._plan.name,
                    id=run_id,
                    duration_ms=duration_ms,
                    status=status,
                )
            )
__all__ = ["StreamingTestRunner"]


def _ensure_config_context(source: ConfigContext | Mapping[str, Any]) -> ConfigContext:
    """Coerce ``source`` into a :class:`ConfigContext`.

    An existing context is returned unchanged; any other mapping is wrapped
    through ``ConfigContext.from_dict``.
    """
    return source if isinstance(source, ConfigContext) else ConfigContext.from_dict(source)


def _test_message(topic: str | None, idx: int, payload: Any) -> Any:
    """Wrap ``payload`` in a :class:`Message` carrying deterministic test metadata.

    The message id is ``test-<idx>`` and the topic is passed through as-is
    (it may be ``None``).
    """
    test_meta = MessageMeta(message_id=f"test-{idx}", topic=topic)
    return Message(payload=payload, meta=test_meta)


def _source_topic_or_none(plan: CompiledPlan) -> str | None:
    """Return the first source topic when the compiled source exposes one.

    Sources other than single/multi compiled sources, and sources with an empty
    topic collection, yield ``None``.
    """
    compiled_source = plan.source
    if not isinstance(compiled_source, (CompiledSingleSource, CompiledMultiSource)):
        return None
    for first_topic in compiled_source.topics:
        return first_topic
    return None