# Source code for loom.streaming.kafka.client._consumer

"""Raw Kafka consumer backed by confluent-kafka."""

from __future__ import annotations

from collections.abc import Iterable, Mapping
from typing import Literal, Protocol, cast

from confluent_kafka import Consumer as _Consumer
from confluent_kafka import Message as _RawMessage
from confluent_kafka import TopicPartition

from loom.core.observability.event import LifecycleEvent, Scope
from loom.core.observability.runtime import ObservabilityRuntime
from loom.streaming.kafka._config import ConsumerSettings
from loom.streaming.kafka._errors import KafkaCommitError, KafkaPollError
from loom.streaming.kafka._message import HEADER_CORRELATION_ID, HEADER_TRACE_ID
from loom.streaming.kafka._record import KafkaRecord


class _CommitMethod(Protocol):
    """Keyword-only call signature of the backend ``commit`` method.

    Serves as the ``cast`` target for the consumer's bound ``commit`` so that
    commit calls are checked against an explicit signature.
    """

    def __call__(
        self, *, asynchronous: bool = ..., offsets: list[TopicPartition] | None = ...
    ) -> object:
        """Invoke the commit; both arguments must be passed by keyword."""
        ...


class KafkaConsumerClient:
    """Confluent-backed raw Kafka consumer.

    Hands out ``KafkaRecord[bytes]`` instances. Nothing is deserialized here:
    record values stay the raw bytes received from the broker.

    Args:
        settings: Typed consumer settings; they supply the backend
            configuration and the topics to subscribe to.
        obs: Optional observability runtime. When provided, every record
            returned by :meth:`poll` emits a ``kafka_consume`` lifecycle event.
    """

    def __init__(
        self,
        settings: ConsumerSettings,
        obs: ObservabilityRuntime | None = None,
    ) -> None:
        self._consumer = _Consumer(settings.to_confluent_config())
        # The subscription is made eagerly, at construction time.
        self._consumer.subscribe(list(settings.topics))
        self._obs = obs

    def poll(self, timeout_ms: int) -> KafkaRecord[bytes] | None:
        """Read one raw byte record from Kafka.

        Args:
            timeout_ms: Maximum poll wait in milliseconds.

        Returns:
            One raw Kafka record, or ``None`` when no record is available.

        Raises:
            KafkaPollError: If the backend poll fails or returns a broker error.
            TypeError: If the polled message carries no value or no topic
                (raised while the message is converted into a record).
        """
        try:
            # The backend expects seconds; the conversion stays inside the
            # ``try`` so a bad timeout is reported as a poll error as well.
            polled = self._consumer.poll(timeout_ms / 1000)
        except Exception as exc:
            raise KafkaPollError(str(exc)) from exc

        if polled is None:
            return None

        broker_error = polled.error()
        if broker_error is not None:
            raise KafkaPollError(str(broker_error))

        record = _to_record(polled)

        observability = self._obs
        if observability is None:
            return record

        observability.emit(
            LifecycleEvent.end(
                scope=Scope.TRANSPORT,
                name="kafka_consume",
                trace_id=_header_trace_id(record.headers),
                correlation_id=_header_correlation_id(record.headers),
                meta={"topic": record.topic},
            )
        )
        return record

    def commit(self, *, asynchronous: bool = False) -> None:
        """Commit consumed offsets.

        Args:
            asynchronous: Whether the backend may commit asynchronously.

        Raises:
            KafkaCommitError: If the backend commit fails.
        """
        try:
            backend_commit = cast(_CommitMethod, self._consumer.commit)
            backend_commit(asynchronous=asynchronous)
        except Exception as exc:
            raise KafkaCommitError(str(exc)) from exc

    def commit_offset(self, partitions: list[TopicPartition]) -> None:
        """Commit explicit Kafka offsets.

        The commit is always requested synchronously.

        Args:
            partitions: Kafka topic-partition offsets to commit.

        Raises:
            KafkaCommitError: If the backend commit fails.
        """
        try:
            backend_commit = cast(_CommitMethod, self._consumer.commit)
            backend_commit(offsets=partitions, asynchronous=False)
        except Exception as exc:
            raise KafkaCommitError(str(exc)) from exc

    def close(self) -> None:
        """Close the consumer and release resources."""
        self._consumer.close()

    def __enter__(self) -> KafkaConsumerClient:
        """Return self for context-manager usage."""
        return self

    def __exit__(self, *exc: object) -> Literal[False]:
        """Close the consumer on context exit.

        A failure while closing is re-raised only when the ``with`` body ended
        without an exception; otherwise it is dropped so it cannot mask the
        exception already in flight. The body's own exception is never
        suppressed.
        """
        try:
            self.close()
        except Exception:
            body_finished_cleanly = exc[0] is None
            if body_finished_cleanly:
                raise
        return False
def _to_record(message: _RawMessage) -> KafkaRecord[bytes]: value_bytes = message.value() if value_bytes is None: raise TypeError("KafkaConsumerClient received a record without a value") _, timestamp_ms = message.timestamp() headers = _normalize_headers(message.headers()) key = message.key() normalized_key: bytes | str | None = None if key is None else key topic = message.topic() if topic is None: raise TypeError("KafkaConsumerClient received a record without a topic") return KafkaRecord( topic=topic, key=normalized_key, value=value_bytes, headers=headers, partition=message.partition(), offset=message.offset(), timestamp_ms=timestamp_ms if timestamp_ms >= 0 else None, ) def _header_trace_id(headers: dict[str, bytes]) -> str | None: raw = headers.get(HEADER_TRACE_ID) return raw.decode() if raw is not None else None def _header_correlation_id(headers: dict[str, bytes]) -> str | None: raw = headers.get(HEADER_CORRELATION_ID) return raw.decode() if raw is not None else None def _normalize_headers( raw_headers: Mapping[str, str | bytes | None] | Iterable[tuple[str, str | bytes | None]] | None, ) -> dict[str, bytes]: """Return Kafka headers without tombstone header values.""" headers: dict[str, bytes] = {} if raw_headers is None: return headers if isinstance(raw_headers, Mapping): iterable = cast(Iterable[tuple[str, str | bytes | None]], raw_headers.items()) else: iterable = raw_headers for header_key, header_value in iterable: if header_value is not None: headers[header_key] = ( header_value if isinstance(header_value, bytes) else header_value.encode("utf-8") ) return headers