# Source code for loom.streaming.nodes._fork

"""Graph-level fork declarations for terminal branching flows."""

from __future__ import annotations

from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from loom.core.model import LoomFrozenStruct
from loom.streaming.core._message import StreamPayload
from loom.streaming.nodes._expr_eval import (
    PredicateSpec,
    SelectorSpec,
    evaluate_predicate,
    select_value,
)

if TYPE_CHECKING:
    # Static analysis sees the real ``Process`` type.
    from loom.streaming.graph._flow import Process
else:

    class Process:
        """Runtime stand-in for ``loom.streaming.graph._flow.Process``.

        The real class is imported only for type checkers. At runtime this
        stub lets subscripted forms such as ``Process[InT, Any]`` evaluate
        without that import.
        NOTE(review): presumably this avoids a circular import — confirm.
        """

        # ``__class_getitem__`` is implicitly a classmethod, so no decorator
        # is required here.
        def __class_getitem__(cls, item: object) -> type[Process]:
            """Ignore the type arguments and return the stub class itself."""
            return cls


# Type variable for the payload type that ``ForkRoute`` and ``Fork`` are
# parameterized over; bound to ``StreamPayload``.
InT = TypeVar("InT", bound=StreamPayload)


class ForkKind(StrEnum):
    """Routing family for a terminal fork declaration."""

    KEYED = "keyed"
    PREDICATE = "predicate"


class ForkRoute(LoomFrozenStruct, Generic[InT], frozen=True):
    """A single predicate-guarded branch belonging to a :class:`Fork`.

    Pattern: Branch declaration.

    Args:
        when: Predicate expression or custom predicate guarding the branch.
        process: Terminal process that runs when ``when`` matches.
    """

    # Condition deciding whether this branch is taken.
    when: PredicateSpec[InT]
    # Process executed for a matching payload.
    process: Process[InT, Any]
class Fork(Generic[InT]):
    """Terminal graph-level routing declaration that physically splits a stream.

    Pattern: Terminal branching.

    In contrast to :class:`loom.streaming.nodes.Router`, every branch is a
    full, independent sub-graph and may hold any process node, ``With`` and
    ``WithAsync`` included. A ``Fork`` is terminal in its parent process:
    no node may be declared after it.

    The :meth:`by` and :meth:`when` constructors reject empty route
    collections; ``__init__`` itself stores its arguments without validation.
    """

    __slots__ = ("_default", "_kind", "_predicate_routes", "_routes", "_selector")

    def __init__(
        self,
        *,
        kind: ForkKind,
        selector: SelectorSpec[InT] | None = None,
        routes: Mapping[object, Process[InT, Any]] | None = None,
        predicate_routes: Sequence[ForkRoute[InT]] = (),
        default: Process[InT, Any] | None = None,
    ) -> None:
        """Store the fork configuration.

        Args:
            kind: Routing family of the fork.
            selector: Key selector, used by keyed forks.
            routes: Keyed branches; ``None`` is treated as no branches.
            predicate_routes: Ordered predicate branches.
            default: Optional fallback process.
        """
        # Copy the caller's collections so the fork keeps its own snapshot.
        keyed_branches = dict(routes) if routes else {}
        ordered_branches = tuple(predicate_routes)

        self._kind = kind
        self._selector = selector
        self._routes = keyed_branches
        self._predicate_routes = ordered_branches
        self._default = default

    @classmethod
    def by(
        cls,
        selector: SelectorSpec[InT],
        branches: Mapping[object, Process[InT, Any]],
        *,
        default: Process[InT, Any] | None = None,
    ) -> Fork[InT]:
        """Build a fork whose branch is picked by a selector result.

        Args:
            selector: Path expression or custom selector.
            branches: Terminal processes keyed by selector result.
            default: Optional fallback process.

        Returns:
            Fork declaration of kind :attr:`ForkKind.KEYED`.

        Raises:
            ValueError: If ``branches`` is empty.
        """
        if branches:
            return cls(
                kind=ForkKind.KEYED,
                selector=selector,
                routes=branches,
                default=default,
            )
        raise ValueError("Fork.by requires at least one keyed route.")

    @classmethod
    def when(
        cls,
        routes: Sequence[ForkRoute[InT]],
        *,
        default: Process[InT, Any] | None = None,
    ) -> Fork[InT]:
        """Build a fork with ordered, first-match predicate branches.

        Args:
            routes: Ordered predicate routes.
            default: Optional fallback process.

        Returns:
            Fork declaration of kind :attr:`ForkKind.PREDICATE`.

        Raises:
            ValueError: If ``routes`` is empty.
        """
        if routes:
            return cls(
                kind=ForkKind.PREDICATE,
                predicate_routes=routes,
                default=default,
            )
        raise ValueError("Fork.when requires at least one predicate route.")

    @property
    def kind(self) -> ForkKind:
        """Routing family of this fork."""
        return self._kind

    @property
    def selector(self) -> SelectorSpec[InT] | None:
        """Key selector used by ``Fork.by`` branches, if any."""
        return self._selector

    @property
    def routes(self) -> Mapping[object, Process[InT, Any]]:
        """Keyed branches of a ``Fork.by`` declaration, as a fresh copy."""
        return self._routes.copy()

    @property
    def predicate_routes(self) -> tuple[ForkRoute[InT], ...]:
        """Ordered predicate routes of a ``Fork.when`` declaration."""
        return self._predicate_routes

    @property
    def default(self) -> Process[InT, Any] | None:
        """Fallback process, or ``None`` when none was given."""
        return self._default
# Public names of this module. ``evaluate_predicate`` and ``select_value`` are
# not defined here; they are re-exported from ``loom.streaming.nodes._expr_eval``.
__all__ = ["Fork", "ForkKind", "ForkRoute", "evaluate_predicate", "select_value"]