Source code for loom.streaming.bytewax.runner

"""StreamingRunner — production Bytewax runtime entry point for Loom flows."""

from __future__ import annotations

import logging
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from datetime import timedelta
from pathlib import Path
from time import perf_counter
from typing import Any, Literal, Protocol, runtime_checkable

from bytewax.dataflow import Dataflow
from bytewax.recovery import RecoveryConfig
from bytewax.run import cli_main

from loom.core.async_bridge import AsyncBridge, build_backend_options
from loom.core.config import ConfigContext, ConfigKey
from loom.core.model import LoomFrozenStruct
from loom.core.observability.config import ObservabilityConfig
from loom.core.observability.event import LifecycleEvent, LifecycleStatus, Scope
from loom.core.observability.runtime import ObservabilityRuntime
from loom.core.runner import shutdown_runner
from loom.core.tracing import generate_trace_id
from loom.streaming.bytewax._adapter import build_dataflow_with_shutdown
from loom.streaming.bytewax._runtime_io import (
    build_commit_tracker,
    build_runtime_error_sinks,
    build_runtime_sink,
    build_runtime_source,
    build_runtime_terminal_sinks,
)
from loom.streaming.bytewax._sink_registry import SinkRegistry
from loom.streaming.compiler import CompiledPlan, compile_flow
from loom.streaming.core._errors import ErrorKind
from loom.streaming.graph._flow import StreamFlow

logger = logging.getLogger(__name__)


[docs] @runtime_checkable class ErrorSink(Protocol): """Bytewax sink protocol for ErrorEnvelope items.""" def build(self, step_id: str, worker_index: int, worker_count: int) -> Any: ...
@dataclass(frozen=True) class _PreparedStreamingRun: """Prepared Bytewax execution bundle. Args: dataflow: Built Bytewax dataflow ready for execution. shutdown: Callback that releases adapter-owned resources. """ dataflow: Dataflow shutdown: Callable[[], None]
[docs] class BytewaxRecoverySettings(LoomFrozenStruct, frozen=True): """Recovery settings for the Bytewax runtime. Args: db_dir: Local recovery directory shared by this execution. backup_interval_ms: Optional snapshot backup interval in milliseconds. """ db_dir: str backup_interval_ms: int | None = None
[docs] class BytewaxRuntimeConfig(LoomFrozenStruct, frozen=True): """Runtime settings for executing a Bytewax dataflow. **Delivery guarantee:** By default Kafka offsets are committed by the consumer's auto-commit policy and are not tied to downstream success. When ``enable.auto.commit`` is disabled in the consumer settings, Loom tracks item completion through the adapter and commits offsets only after the downstream branch finishes successfully. Args: workers_per_process: Number of worker threads in this process. process_id: Optional cluster process id. addresses: Optional cluster address list, including this process. epoch_interval_ms: Optional epoch duration in milliseconds. recovery: Optional recovery settings. async_backend: Anyio backend used for :class:`AsyncBridge` when the flow contains :class:`~loom.streaming.WithAsync` nodes. Accepted values: ``"asyncio"`` (default) or ``"trio"``. use_uvloop: Replace the asyncio event loop with uvloop for lower latency async I/O. Only effective when *async_backend* is ``"asyncio"`` and the platform is not Windows. force_shutdown_timeout_ms: Maximum milliseconds to wait for in-flight async tasks to finish during shutdown. When exceeded, the portal thread is abandoned and a warning is logged. ``None`` waits indefinitely (default). """ workers_per_process: int = 1 process_id: int | None = None addresses: tuple[str, ...] | None = None epoch_interval_ms: int | None = None recovery: BytewaxRecoverySettings | None = None async_backend: Literal["asyncio", "trio"] = "asyncio" use_uvloop: bool = False force_shutdown_timeout_ms: int | None = None
[docs] class StreamingRunner: """Wire a :class:`~loom.streaming.StreamFlow` declaration into the real Bytewax runtime. Args: _registry: Optional pre-built SinkRegistry for testing purposes. """ def __init__(self, *, _registry: SinkRegistry | None = None) -> None: self._flow: StreamFlow[Any, Any] | None = None self._plan: CompiledPlan | None = None self._runtime: BytewaxRuntimeConfig = BytewaxRuntimeConfig() self._observability_runtime: ObservabilityRuntime = ObservabilityRuntime.noop() self._log = logger self._shutdown: Callable[[], None] | None = None self._sink_registry: SinkRegistry = _registry or SinkRegistry() self._resolved_error_sinks: dict[ErrorKind, Any] = {} def _configure( self, flow: StreamFlow[Any, Any], plan: CompiledPlan, runtime: BytewaxRuntimeConfig, observability_runtime: ObservabilityRuntime, ) -> None: self._flow = flow self._plan = plan self._runtime = runtime self._observability_runtime = observability_runtime
[docs] @classmethod def from_context( cls, flow: StreamFlow[Any, Any], *, config: ConfigContext, observability_runtime: ObservabilityRuntime | None = None, ) -> StreamingRunner: """Build a runner from a resolved config context.""" plan = compile_flow(flow, config=config) runtime = config.section_or_default( ConfigKey.STREAMING_RUNTIME, BytewaxRuntimeConfig, BytewaxRuntimeConfig(), ) observability_cfg = config.section_optional( ConfigKey.OBSERVABILITY, ObservabilityConfig, ) if observability_cfg is None: observability_cfg = ObservabilityConfig() obs = observability_runtime or ObservabilityRuntime.from_config(observability_cfg) runner = cls() runner._configure(flow, plan, runtime, obs) return runner
[docs] @classmethod def from_yaml( cls, flow: StreamFlow[Any, Any], path: str, *, observability_runtime: ObservabilityRuntime | None = None, ) -> StreamingRunner: """Load config from YAML and build a runner.""" return cls.from_context( flow, config=ConfigContext.from_yaml(path), observability_runtime=observability_runtime, )
[docs] @classmethod def from_dict( cls, flow: StreamFlow[Any, Any], config: dict[str, Any], *, observability_runtime: ObservabilityRuntime | None = None, ) -> StreamingRunner: """Build a runner from a plain Python config mapping.""" return cls.from_context( flow, config=ConfigContext.from_dict(config), observability_runtime=observability_runtime, )
@property def _registry(self) -> SinkRegistry: """Internal sink registry for inspection in tests.""" return self._sink_registry
[docs] def register_sink(self, sink_class: type) -> None: """Register a sink class to be resolved from YAML config at run time. Args: sink_class: Class that satisfies the RegisteredSink protocol. Raises: TypeError: If sink_class does not implement RegisteredSink. """ self._sink_registry.register(sink_class)
[docs] def build_dataflow(self) -> Dataflow: """Assemble the Bytewax Dataflow and keep shutdown state.""" prepared = self.prepare_run() return prepared.dataflow
[docs] def prepare_run( self, *, error_sinks: Mapping[ErrorKind, ErrorSink] | None = None, ) -> _PreparedStreamingRun: """Prepare one executable dataflow and its shutdown callback. Releases any resources from a previous ``prepare_run()`` call before building the new dataflow, so calling this method twice is safe. Args: error_sinks: Optional mapping of :class:`~loom.streaming.core._errors.ErrorKind` to :class:`ErrorSink` instances. When provided, these sinks override the default error routing built from the compiled plan. Returns: A bundle containing the assembled :class:`~bytewax.dataflow.Dataflow` and a shutdown callable that releases adapter-owned resources. """ if self._plan is None: raise RuntimeError( "StreamingRunner has no compiled plan. " "Use run(flow=..., config_path=...) or a factory method " "(from_yaml / from_context / from_dict) before calling prepare_run." ) shutdown_runner(self) merged_error_sinks: Mapping[ErrorKind, Any] | None if error_sinks is not None: merged_error_sinks = error_sinks else: merged_error_sinks = self._resolved_error_sinks or None prepared = _prepare_run( self._plan, observability_runtime=self._observability_runtime, runtime=self._runtime, error_sinks=merged_error_sinks, ) self._shutdown = prepared.shutdown return prepared
[docs] def run( self, *, flow: StreamFlow[Any, Any] | None = None, config: ConfigContext | None = None, config_path: str | Path | None = None, error_sinks: Mapping[ErrorKind, Any] | None = None, runtime: BytewaxRuntimeConfig | None = None, ) -> None: """Execute the compiled dataflow with the real Bytewax runtime. Emits ``Scope.POLL_CYCLE`` around the full invocation (including preparation) and ``Scope.FLOW`` around the blocking ``cli_main`` call. Starts the Prometheus scrape server before execution when configured. Args: flow: Optional flow override. When provided together with ``config`` or ``config_path``, the runner is reconfigured before execution. config: Optional config context. Used together with ``flow`` to reconfigure the runner. config_path: Optional path to a YAML config file. Loaded and used like ``config`` when provided. error_sinks: Optional explicit error sinks. When provided, these take precedence over registry-resolved sinks. runtime: Optional runtime override. When omitted, uses the runner's config-loaded runtime settings. """ if config_path is not None: config = ConfigContext.from_yaml(str(config_path)) if flow is not None and config is not None: plan = compile_flow(flow, config=config) runtime_cfg = config.section_or_default( ConfigKey.STREAMING_RUNTIME, BytewaxRuntimeConfig, BytewaxRuntimeConfig() ) obs_cfg = config.section_optional(ConfigKey.OBSERVABILITY, ObservabilityConfig) obs = ObservabilityRuntime.from_config(obs_cfg or ObservabilityConfig()) self._configure(flow, plan, runtime_cfg, obs) if config is not None: bindings = self._sink_registry.resolve(config) self._resolved_error_sinks = { kind: b.sink for b in bindings if b.purpose == "errors" for kind in b.kinds } if error_sinks is not None: self._resolved_error_sinks = error_sinks # type: ignore[assignment] if self._plan is None: raise RuntimeError( "StreamingRunner has no compiled plan. " "Use run(flow=..., config_path=...) or a factory method " "(from_yaml / from_context / from_dict) before calling run." ) resolved_runtime = runtime or self._runtime run_id = generate_trace_id() self._observability_runtime.emit( LifecycleEvent.start( scope=Scope.POLL_CYCLE, name=self._plan.name, id=run_id, meta={"node_count": len(self._plan.nodes)}, ) ) started_at = perf_counter() status = LifecycleStatus.FAILURE try: self._observability_runtime.start_scrape_server() prepared = self.prepare_run() with self._observability_runtime.span( Scope.FLOW, self._plan.name, node_count=len(self._plan.nodes), ): cli_main(prepared.dataflow, **_runtime_kwargs(resolved_runtime)) # type: ignore[no-untyped-call] status = LifecycleStatus.SUCCESS except Exception: status = LifecycleStatus.FAILURE raise finally: duration_ms = int((perf_counter() - started_at) * 1000) self._observability_runtime.emit( LifecycleEvent.end( scope=Scope.POLL_CYCLE, name=self._plan.name, id=run_id, duration_ms=duration_ms, status=status, ) ) shutdown_runner(self)
[docs] def shutdown(self) -> None: """Release adapter resources after a run or failed build.""" if self._shutdown is not None: self._log.debug("shutting_down") self._shutdown() self._shutdown = None
def _runtime_kwargs(runtime: BytewaxRuntimeConfig) -> dict[str, Any]: kwargs: dict[str, Any] = { "workers_per_process": runtime.workers_per_process, "process_id": runtime.process_id, "addresses": list(runtime.addresses) if runtime.addresses is not None else None, "epoch_interval": _to_timedelta(runtime.epoch_interval_ms), "recovery_config": _recovery_config(runtime.recovery), } return kwargs def _recovery_config(recovery: BytewaxRecoverySettings | None) -> RecoveryConfig | None: if recovery is None: return None return RecoveryConfig( # type: ignore[no-untyped-call] Path(recovery.db_dir), backup_interval=_to_timedelta(recovery.backup_interval_ms), ) def _to_timedelta(value_ms: int | None) -> timedelta | None: if value_ms is None: return None return timedelta(milliseconds=value_ms) def _prepare_run( plan: CompiledPlan, *, observability_runtime: ObservabilityRuntime | None = None, source: Any | None = None, sink: Any | None = None, terminal_sinks: Mapping[tuple[int, ...], Any] | None = None, error_sinks: Mapping[ErrorKind, Any] | None = None, runtime: BytewaxRuntimeConfig | None = None, ) -> _PreparedStreamingRun: commit_tracker = build_commit_tracker(plan.source) if source is None: source = build_runtime_source(plan.source, commit_tracker) if sink is None and plan.output is not None: sink = build_runtime_sink(plan.output, commit_tracker) if terminal_sinks is None: terminal_sinks = build_runtime_terminal_sinks(plan.terminal_sinks, commit_tracker) if error_sinks is None: error_sinks = build_runtime_error_sinks(plan.error_routes, commit_tracker) bridge = _create_bridge(plan, runtime or BytewaxRuntimeConfig()) built = build_dataflow_with_shutdown( plan=plan, observability_runtime=observability_runtime, source=source, sink=sink, terminal_sinks=terminal_sinks, error_sinks=error_sinks, bridge=bridge, commit_tracker=commit_tracker, ) return _PreparedStreamingRun(dataflow=built.dataflow, shutdown=built.shutdown) def _create_bridge(plan: CompiledPlan, runtime: BytewaxRuntimeConfig) -> AsyncBridge | None: """Create an AsyncBridge configured from runtime settings, or None if not needed.""" if not plan.needs_async_bridge: return None return AsyncBridge( backend=runtime.async_backend, backend_options=build_backend_options(runtime.async_backend, runtime.use_uvloop), shutdown_timeout_ms=runtime.force_shutdown_timeout_ms, ) __all__ = [ "BytewaxRecoverySettings", "BytewaxRuntimeConfig", "ErrorSink", "StreamingRunner", ]