# Source code for loom.streaming.bytewax._adapter

"""Bytewax runtime adapter.

Translates a :class:`CompiledPlan` into a Bytewax :class:`Dataflow`,
wiring decode, node dispatch, encode, and output routing operators.

Requires ``bytewax`` to be installed.
"""

from __future__ import annotations

from collections.abc import Callable, Iterator, Mapping
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Protocol, TypeAlias, cast, runtime_checkable

import bytewax.dataflow as _bytewax_dataflow
from bytewax.operators import branch
from bytewax.operators import input as bw_input
from bytewax.operators import map as bw_map
from bytewax.operators import output as bw_output
from bytewax.outputs import DynamicSink, StatelessSinkPartition

from loom.core.async_bridge import AsyncBridge
from loom.core.logger import get_logger
from loom.core.observability.runtime import ObservabilityRuntime
from loom.streaming.bytewax._resource_manager import ResourceManager
from loom.streaming.bytewax._runtime_io import build_runtime_terminal_sinks
from loom.streaming.bytewax.handlers.dispatcher import (
    _NODE_HANDLERS,
    _wire_process,
)
from loom.streaming.compiler import CompiledPlan
from loom.streaming.compiler._plan import CompiledMultiSource, CompiledSingleSource
from loom.streaming.core._errors import ErrorKind
from loom.streaming.core._message import Message
from loom.streaming.core._typing import StreamPayload
from loom.streaming.kafka._codec import MsgspecCodec
from loom.streaming.kafka._record import KafkaRecord
from loom.streaming.kafka._wire import (
    DecodeOk,
    DecodeResult,
    try_decode_multi_record,
    try_decode_record,
)
from loom.streaming.nodes._with import With, WithAsync

# Bytewax stream handles are treated as opaque values throughout this module.
Stream: TypeAlias = Any
# Module-level logger named after this module.
logger = get_logger(__name__)

# Names exported by this module; ``_NODE_HANDLERS`` is re-exported from the
# handlers dispatcher module imported above.
__all__ = ["build_dataflow", "build_dataflow_with_shutdown", "_NODE_HANDLERS"]


@dataclass(frozen=True)
class _BuiltDataflow:
    """Immutable pair returned by the adapter's build step.

    Attributes:
        dataflow: The assembled Bytewax dataflow object.
        shutdown: Zero-argument callback owned by the adapter that releases
            the resources created while building the dataflow.
    """

    # Assembled Bytewax dataflow.
    dataflow: Any
    # Adapter-owned cleanup hook; takes no arguments and returns nothing.
    shutdown: Callable[[], None]


class _DropSinkPartition(StatelessSinkPartition[Any]):
    """Sink partition that throws away every batch it is handed.

    Used for error branches that have no configured destination.
    """

    def write_batch(self, items: list[Any]) -> None:
        """Accept *items* and intentionally do nothing with them."""
        return None


class _DropSink(DynamicSink[Any]):
    """Dynamic sink whose partitions discard everything written to them.

    Attached to error branches for which no error sink was configured.
    """

    def build(
        self, step_id: str, worker_index: int, worker_count: int
    ) -> StatelessSinkPartition[Any]:
        """Return a new discarding partition; all arguments are ignored."""
        partition = _DropSinkPartition()
        return partition


def build_dataflow(
    plan: CompiledPlan,
    *,
    observability_runtime: ObservabilityRuntime | None = None,
    source: Any | None = None,
    sink: Any | None = None,
    terminal_sinks: Mapping[tuple[int, ...], Any] | None = None,
    error_sinks: Mapping[ErrorKind, Any] | None = None,
) -> Any:
    """Build a Bytewax Dataflow from a compiled plan.

    Thin convenience wrapper around :func:`build_dataflow_with_shutdown`:
    every argument is forwarded unchanged and only the ``dataflow`` attribute
    of the result is returned, so the shutdown callback is not exposed.
    """
    built = build_dataflow_with_shutdown(
        plan,
        observability_runtime=observability_runtime,
        source=source,
        sink=sink,
        terminal_sinks=terminal_sinks,
        error_sinks=error_sinks,
    )
    return built.dataflow


def build_dataflow_with_shutdown(
    plan: CompiledPlan,
    *,
    observability_runtime: ObservabilityRuntime | None = None,
    source: Any | None = None,
    sink: Any | None = None,
    terminal_sinks: Mapping[tuple[int, ...], Any] | None = None,
    error_sinks: Mapping[ErrorKind, Any] | None = None,
    bridge: AsyncBridge | None = None,
    commit_tracker: Any | None = None,
) -> _BuiltDataflow:
    """Build a Bytewax Dataflow and expose its shutdown callback.

    Args:
        plan: Compiled flow plan.
        observability_runtime: Optional observability runtime for lifecycle
            events. A no-op runtime is substituted when omitted.
        source: Optional Bytewax source override (used in tests).
        sink: Optional Bytewax sink override (used in tests).
        terminal_sinks: Optional per-branch sink overrides. When ``None``,
            sinks are built from ``plan.terminal_sinks``.
        error_sinks: Optional per-kind error sink overrides.
        bridge: Pre-configured :class:`AsyncBridge`. When ``None``, a default
            asyncio bridge is created if the plan requires async execution.
            Pass an explicit bridge to control backend and uvloop settings.
        commit_tracker: Optional commit tracker. It is passed to the runtime
            terminal sink builder, bound to the source, sink, terminal sinks
            and error sinks that expose ``bind_commit_tracker``, and stored
            on the build context.

    Returns:
        The assembled dataflow together with the build context's
        ``shutdown_all`` callback.
    """
    if terminal_sinks is None:
        terminal_sinks = build_runtime_terminal_sinks(plan.terminal_sinks, commit_tracker)
    resolved_runtime = observability_runtime or ObservabilityRuntime.noop()

    # Hand the tracker to every runtime endpoint that knows how to accept it.
    for endpoint in (source, sink):
        _bind_commit_tracker_object(endpoint, commit_tracker)
    for sink_map in (terminal_sinks, error_sinks):
        _bind_commit_tracker_mapping(sink_map, commit_tracker)

    resolved_bridge = _maybe_create_bridge(plan) if bridge is None else bridge
    ctx = _BuildContext(
        plan=plan,
        bridge=resolved_bridge,
        flow_runtime=resolved_runtime,
        source=source,
        sink=sink,
        terminal_sinks=terminal_sinks,
        error_sinks=error_sinks,
        commit_tracker=commit_tracker,
    )
    dataflow = _assemble_dataflow(plan, ctx)
    return _BuiltDataflow(dataflow=dataflow, shutdown=ctx.shutdown_all)


@runtime_checkable class _SupportsCommitBind(Protocol): """Runtime object that accepts a Kafka commit tracker.""" def bind_commit_tracker(self, tracker: Any) -> None: """Bind a commit tracker to this runtime object.""" def _bind_commit_tracker_object(item: object | None, commit_tracker: object | None) -> None: """Bind a commit tracker to one runtime object when supported.""" if item is None or commit_tracker is None: return if isinstance(item, _SupportsCommitBind): item.bind_commit_tracker(commit_tracker) def _bind_commit_tracker_mapping( items: Mapping[Any, Any] | None, commit_tracker: Any | None, ) -> None: """Bind a commit tracker to each runtime object in a mapping when supported.""" if items is None or commit_tracker is None: return for item in items.values(): _bind_commit_tracker_object(item, commit_tracker) def _assemble_dataflow(plan: CompiledPlan, ctx: _BuildContext) -> Any: """Assemble a Bytewax Dataflow from a pre-built context.""" flow = _bytewax_dataflow.Dataflow(plan.name) stream = _build_source_pipeline(flow, ctx) stream = _wire_process(stream, tuple(node.node for node in plan.nodes), ctx) _wire_output(stream, ctx) return flow class _BuildContext: """Wiring-phase state shared across operator builders.""" __slots__ = ( "plan", "bridge", "commit_tracker", "flow_runtime", "source", "sink", "error_sinks", "terminal_sinks", "resource_manager", "_path", ) def __init__( self, plan: CompiledPlan, bridge: AsyncBridge | None, flow_runtime: ObservabilityRuntime, source: Any | None = None, sink: Any | None = None, terminal_sinks: Mapping[tuple[int, ...], Any] | None = None, error_sinks: Mapping[ErrorKind, Any] | None = None, commit_tracker: Any | None = None, ) -> None: self.plan = plan self.bridge = bridge self.commit_tracker = commit_tracker self.flow_runtime = flow_runtime self.source = source self.sink = sink self.terminal_sinks = terminal_sinks or {} self.error_sinks = error_sinks or {} self.resource_manager = ResourceManager(bridge) self._path: tuple[int, 
...] = () def wire_terminal(self, step_id: str, stream: Any) -> None: if self.sink is None: raise RuntimeError("Bytewax sink is required for terminal output wiring.") bw_output(step_id, stream, self.sink) def wire_branch_terminal(self, step_id: str, stream: Any, path: tuple[int, ...]) -> None: sink = self.terminal_sinks.get(path) if sink is None: return bw_output(_qualified_step_id(step_id, path), stream, sink) def wire_node_error(self, kind: ErrorKind, step_id: str, stream: Any) -> None: sink = self.error_sinks.get(kind) if sink is not None: bw_output(f"{step_id}_{kind.value}_errors", stream, sink) return logger.warning( "unrouted_error_drop_sink", flow=self.plan.name, kind=kind.value, step_id=step_id, ) bw_output(f"{step_id}_{kind.value}_dropped", stream, _DropSink()) def wire_flow_output(self, stream: Any, plan: CompiledPlan) -> None: if self.sink is None and plan.output is not None: raise RuntimeError("Bytewax sink is required for terminal output wiring.") if self.sink is not None: self.wire_terminal("output", stream) for kind in plan.error_routes: if kind not in self.error_sinks: raise RuntimeError(f"Bytewax sink is required for error route {kind.value}.") def wire_decode_error(self, stream: Any, plan: CompiledPlan) -> None: del plan if self.error_sinks.get(ErrorKind.WIRE) is not None: self.wire_node_error(ErrorKind.WIRE, "decode", stream) return if ErrorKind.WIRE in self.plan.error_routes: raise RuntimeError("Bytewax sink is required for WIRE error routing.") def inline_sink_partition_for( self, path: tuple[int, ...], ) -> StatelessSinkPartition[Any] | None: """Build an inline sink partition for the given path. Delegates to the runtime-wired Bytewax ``Sink`` for the path so that test doubles (e.g. ``TestingSink``) are honoured instead of always creating real Kafka producers. Args: path: Compiled path identifying the terminal sink. Returns: A ready-to-write ``StatelessSinkPartition``, or ``None`` if no sink is registered for *path*. 
""" sink = self.terminal_sinks.get(path) if sink is None: return None step_id = "inline_" + "_".join(str(p) for p in path) return cast(StatelessSinkPartition[Any], sink.build(step_id, 0, 1)) def manager_for( self, idx: int, node: With[StreamPayload, StreamPayload] | WithAsync[StreamPayload, StreamPayload], ) -> Any: """Get or create a resource manager for *node* at position *idx*.""" return self.resource_manager.manager_for(idx, node) def session_manager_for(self, config: Any) -> Any: """Get or create a shared SQLAlchemy session manager for one config.""" return self.resource_manager.session_manager_for(config) @property def current_path(self) -> tuple[int, ...]: """Return the current wiring path inside the process tree.""" return self._path @contextmanager def enter_path(self, path: tuple[int, ...]) -> Iterator[None]: """Temporarily set the current wiring path.""" previous = self._path self._path = path try: yield finally: self._path = previous def wire_process( self, stream: Any, nodes: tuple[object, ...], *, path_prefix: tuple[int, ...] 
= (), ) -> Any: """Wire one nested process subtree.""" return _wire_process(stream, nodes, self, path_prefix=path_prefix) def shutdown_all(self) -> None: """Shutdown all resource managers.""" self.resource_manager.shutdown_all() def _build_source_pipeline(flow: Any, ctx: _BuildContext) -> Stream: """Build the source-side pipeline up to the first decoded Message stream.""" if ctx.source is None: raise RuntimeError("Bytewax source is required to build a runtime dataflow.") source = ctx.source if not ctx.plan.source.needs_decode: return bw_input("source", flow, source) codec: MsgspecCodec[Any] = MsgspecCodec() if not isinstance(ctx.plan.source, (CompiledSingleSource, CompiledMultiSource)): raise TypeError(f"Expected a Kafka compiled source, got {type(ctx.plan.source).__name__}.") strategy = ctx.plan.source.decode_strategy step_id = f"decode_{strategy}" stream: Stream = bw_input("source", flow, source) decoded = _decode_source_stream(stream, ctx, codec, step_id) decoded_branch = _split_decode_results(decoded, step_id) ctx.wire_decode_error(decoded_branch.falses, ctx.plan) return bw_map(f"{step_id}_message", decoded_branch.trues, _decode_ok_message) def _decode_source_stream( stream: Stream, ctx: _BuildContext, codec: MsgspecCodec[Any], step_id: str, ) -> Stream: """Map raw source items into decode results without raising wire errors.""" return bw_map(step_id, stream, lambda item: _decode_source_record(item, ctx, codec)) def _split_decode_results(stream: Stream, step_id: str) -> Any: """Split decode results into successful messages and wire errors.""" return branch(f"{step_id}_is_ok", stream, _is_decode_ok) def _wire_output(stream: Any, ctx: _BuildContext) -> None: """Wire the output sink and error routes.""" ctx.wire_flow_output(stream, ctx.plan) def _qualified_step_id(step_id: str, path: tuple[int, ...]) -> str: if not path: return step_id return "_".join((step_id, *map(str, path))) def _maybe_create_bridge(plan: CompiledPlan) -> AsyncBridge | None: """Create a 
default asyncio AsyncBridge if the plan requires async execution. Used as a fallback when no pre-configured bridge is supplied to :func:`build_dataflow_with_shutdown` — e.g. in test helpers or direct adapter use. Production runners should pass an explicit bridge created via :func:`~loom.streaming.bytewax.runner._create_bridge` so that backend and uvloop settings from :class:`BytewaxRuntimeConfig` are applied. """ if not plan.needs_async_bridge: return None return AsyncBridge() def _decode_source_record( payload: Any, ctx: _BuildContext, codec: MsgspecCodec[Any], ) -> DecodeResult[StreamPayload]: """Decode source records into DSL messages without raising decode errors.""" if isinstance(payload, Message): return DecodeOk(message=cast(Message[StreamPayload], payload)) if isinstance(payload, KafkaRecord): record = cast(KafkaRecord[bytes], payload) source = ctx.plan.source if not source.needs_decode: raise TypeError("Mongo CDC sources must emit Message values, not KafkaRecord items.") if isinstance(source, CompiledMultiSource): return try_decode_multi_record(record, source.dispatch, codec) if not isinstance(source, CompiledSingleSource): raise TypeError(f"Expected CompiledSingleSource, got {type(source).__name__}.") return try_decode_record(record, source.payload_type, codec) raise TypeError(f"Expected Message or KafkaRecord, got {type(payload).__name__}.") def _is_decode_ok(result: DecodeResult[StreamPayload]) -> bool: """Return whether a source decode result can continue through the flow.""" return isinstance(result, DecodeOk) def _decode_ok_message(result: DecodeResult[StreamPayload]) -> Message[StreamPayload]: """Unwrap a successful source decode result.""" if isinstance(result, DecodeOk): return result.message raise TypeError(f"Expected DecodeOk, got {type(result).__name__}.")