Source code for loom.streaming.nodes._with

"""Scoped dependency declarations for streaming tasks."""

from __future__ import annotations

from collections.abc import Callable, Mapping
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from enum import StrEnum
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeGuard, TypeVar

from loom.core.config import Configurable
from loom.core.logger import LoggerPort, get_logger
from loom.core.model import LoomFrozenStruct, LoomStruct

if TYPE_CHECKING:
    from loom.streaming.graph._flow import Process

InT = TypeVar("InT", bound=LoomStruct | LoomFrozenStruct)
OutT = TypeVar("OutT", bound=LoomStruct | LoomFrozenStruct)

SyncContextDependency = AbstractContextManager[object]
AsyncContextDependency = AbstractAsyncContextManager[object]
ContextDependency = SyncContextDependency | AsyncContextDependency


[docs] class ResourceScope(StrEnum): """Context-manager lifecycle for dependencies declared through ``With``. Pattern: Wrapper scope. """ WORKER = "worker" BATCH = "batch"
[docs] class ContextFactory(Configurable): """Factory that creates a fresh context manager on demand. Pattern: Wrapper dependency. Use this when you need a new context-manager instance per batch (``scope=BATCH``) or when you want to configure the factory from YAML. Args: factory: Callable that returns a context-manager instance. Example:: db = ContextFactory(lambda: SessionLocal()) With(process=Process(ValidateOrder, IntoTopic("validated")), scope=BATCH, db=db) """ __slots__ = ("_factory", "_log") def __init__(self, factory: Callable[[], ContextDependency]) -> None: self._factory = factory self._log: LoggerPort | None = None
[docs] def create(self) -> ContextDependency: """Return a fresh context-manager instance.""" return self._factory()
@property def log(self) -> LoggerPort: """Structured logger bound to this factory class and module.""" logger = self._log if logger is None: cls = type(self) logger = get_logger(cls.__qualname__).bind( component="context_factory", class_name=cls.__qualname__, module=cls.__module__, factory_name=cls.__qualname__, ) self._log = logger return logger
class _WithBase(Generic[InT, OutT]): """Shared dependency scope and lifecycle management for streaming nodes. Pattern: Wrapper node. Classifies keyword dependencies into async context managers, sync context managers, :class:`ContextFactory` instances, and plain values. Subclasses are responsible for declaring and validating their primary execution target (a step or a process). Args: scope: Lifecycle used when opening context-manager dependencies. process: Inner process executed by the wrapper node. **dependencies: Named dependencies injected at execution time. """ __slots__ = ( "_async_contexts", "_context_factories", "_plain_deps", "_sync_contexts", "process", "scope", ) router_branch_safe: ClassVar[bool] = True def __init__( self, scope: ResourceScope = ResourceScope.WORKER, *, process: Process[Any, Any], **dependencies: object, ) -> None: self.process = process self.scope = scope sync_contexts: dict[str, SyncContextDependency] = {} async_contexts: dict[str, AsyncContextDependency] = {} context_factories: dict[str, ContextFactory] = {} plain_deps: dict[str, object] = {} for name, dependency in dependencies.items(): if isinstance(dependency, ContextFactory): context_factories[name] = dependency elif _is_sync_context_manager(dependency): sync_contexts[name] = dependency elif _is_async_context_manager(dependency): async_contexts[name] = dependency else: plain_deps[name] = dependency self._sync_contexts = sync_contexts self._async_contexts = async_contexts self._context_factories = context_factories self._plain_deps = plain_deps @property def sync_contexts(self) -> Mapping[str, SyncContextDependency]: """Sync context-manager dependencies keyed by injection name.""" return MappingProxyType(self._sync_contexts) @property def async_contexts(self) -> Mapping[str, AsyncContextDependency]: """Async context-manager dependencies keyed by injection name.""" return MappingProxyType(self._async_contexts) @property def context_factories(self) -> Mapping[str, ContextFactory]: """Context-manager factories keyed by injection name.""" return MappingProxyType(self._context_factories) @property def plain_deps(self) -> Mapping[str, object]: """Plain dependencies keyed by injection name.""" return MappingProxyType(self._plain_deps)
[docs] class With(_WithBase[InT, OutT]): """Declare a sync dependency scope around an inner process. Pattern: Sync wrapper. Each incoming message flows through the inner :class:`~loom.streaming.graph.Process` synchronously. If the last node of the inner process is an :class:`~loom.streaming.nodes._boundary.IntoTopic`, results are written directly to Kafka as each message completes — no ``ForEach`` or outer ``IntoTopic`` is required. The outer stream is drained after this node (``StreamShape.NONE``). Args: scope: Lifecycle used when opening context-manager dependencies. process: Inner process executed per message. **dependencies: Named sync context managers, factories, or plain values injected into step execution. Raises: TypeError: If an async context manager is passed directly. Example:: With( process=Process(ValidateOrder(), IntoTopic("validated", payload=Validated)), scope=ResourceScope.WORKER, db=SessionLocal(), ) """ __slots__ = () def __init__( self, scope: ResourceScope = ResourceScope.WORKER, *, process: Process[Any, Any], **dependencies: object, ) -> None: super().__init__(scope=scope, process=process, **dependencies) if self._async_contexts: names = ", ".join(self._async_contexts.keys()) raise TypeError( f"With only accepts sync context managers. " f"Async context manager(s) found: {names}. " f"Use WithAsync for async context managers." )
[docs] class WithAsync(_WithBase[InT, OutT]): """Declare an async dependency scope around an inner process. Pattern: Async wrapper. Each incoming message flows through the inner :class:`~loom.streaming.graph.Process` asynchronously. If the last node of the inner process is an :class:`~loom.streaming.nodes._boundary.IntoTopic`, results are written directly to Kafka as each message completes — no ``ForEach`` or outer ``IntoTopic`` is required. The outer stream is drained after this node (``StreamShape.NONE``). Args: scope: Lifecycle used when opening context-manager dependencies. max_concurrency: Maximum concurrent message executions. task_timeout_ms: Optional per-message wall-clock deadline in milliseconds. When set, each message execution is cancelled with :class:`TimeoutError` if it exceeds the deadline. process: Inner process executed per message. **dependencies: Named async context managers, factories, or plain values injected into step execution. Raises: TypeError: If a sync context manager is passed directly. ValueError: If *max_concurrency* is not positive. ValueError: If *task_timeout_ms* is not positive when provided. Example:: WithAsync( process=Process(FetchTask(), IntoTopic("results", payload=Result)), scope=ResourceScope.WORKER, max_concurrency=50, http=ContextFactory(lambda: httpx.AsyncClient()), ) """ __slots__ = ("max_concurrency", "task_timeout_ms") def __init__( self, scope: ResourceScope = ResourceScope.WORKER, max_concurrency: int = 10, task_timeout_ms: int | None = None, *, process: Process[Any, Any], **dependencies: object, ) -> None: if max_concurrency <= 0: raise ValueError("max_concurrency must be greater than zero") if task_timeout_ms is not None and task_timeout_ms <= 0: raise ValueError("task_timeout_ms must be greater than zero when provided") super().__init__(scope=scope, process=process, **dependencies) if self._sync_contexts: names = ", ".join(self._sync_contexts.keys()) raise TypeError( f"WithAsync only accepts async context managers. " f"Sync context manager(s) found: {names}. " f"Use With for sync context managers." ) self.max_concurrency = max_concurrency self.task_timeout_ms: int | None = task_timeout_ms
def _is_sync_context_manager(value: object) -> TypeGuard[SyncContextDependency]: return isinstance(value, AbstractContextManager) def _is_async_context_manager(value: object) -> TypeGuard[AsyncContextDependency]: return isinstance(value, AbstractAsyncContextManager) __all__ = [ "AsyncContextDependency", "ContextDependency", "ContextFactory", "ResourceScope", "SyncContextDependency", "With", "WithAsync", ]