"""BSON-to-Loom normalization helpers for MongoDB CDC."""
from __future__ import annotations
import base64
from collections.abc import Callable, Mapping, Sequence
from datetime import UTC, datetime
from typing import Protocol, cast
import msgspec
from loom.streaming.core._message import Message, MessageMeta
from loom.streaming.mongo._event import (
MongoBsonTimestamp,
MongoCDCEvent,
MongoCDCNamespace,
MongoDBRef,
MongoObjectId,
)
_MONGO_MESSAGE_TYPE = "loom.mongo.cdc"
_MAX_BSON_DEPTH = 64
class _SupportsBytes(Protocol):
def __bytes__(self) -> bytes:
"""Return a bytes representation."""
...
[docs]
def normalize_bson_value(value: object, _depth: int = 0) -> object:
    """Normalize one MongoDB/BSON runtime value into Loom-safe builtins.

    Conversion rules, applied in order:

    * ``None``, ``bool``, ``int``, ``float`` and ``str`` are returned unchanged.
    * ``datetime`` becomes integer epoch milliseconds (naive values read as UTC).
    * Any ``Mapping`` becomes a ``dict`` with ``str`` keys, normalized recursively.
    * Any other ``Sequence`` (list, tuple, ...) except ``bytes``/``bytearray``/
      ``memoryview`` becomes a ``list``, normalized recursively.
    * Everything else is dispatched on its exact class name through
      ``_BSON_NORMALIZERS`` (e.g. ``ObjectId``, ``Binary``); unknown types are
      returned as-is.

    Args:
        value: The value to normalize.
        _depth: Internal recursion depth; callers should not pass it.

    Returns:
        The normalized value.

    Raises:
        ValueError: If nesting exceeds ``_MAX_BSON_DEPTH`` levels.
    """
    if _depth > _MAX_BSON_DEPTH:
        raise ValueError(f"BSON document exceeds maximum nesting depth of {_MAX_BSON_DEPTH}.")
    if value is None or isinstance(value, (bool, int, float, str)):
        return value
    if isinstance(value, datetime):
        return _datetime_to_epoch_ms(value)
    child_depth = _depth + 1
    if isinstance(value, Mapping):
        return {str(key): normalize_bson_value(item, child_depth) for key, item in value.items()}
    # ``tuple`` is a ``Sequence`` too, so one branch covers lists and tuples. Binary-like
    # sequences are excluded so they reach the name-based normalizers below; ``str`` was
    # already returned above.
    if isinstance(value, Sequence) and not isinstance(value, (bytes, bytearray, memoryview)):
        return [normalize_bson_value(item, child_depth) for item in value]
    normalizer = _BSON_NORMALIZERS.get(type(value).__name__, _identity)
    return normalizer(value)
[docs]
def build_mongo_cdc_event(change: Mapping[str, object]) -> MongoCDCEvent:
    """Turn one raw change-stream document into a Loom-safe ``MongoCDCEvent``.

    The envelope fields are validated first: the ``_id`` resume token,
    ``operationType``, ``ns``, ``clusterTime`` and ``wallTime``. The whole document
    is then normalized with ``normalize_bson_value``. That result supplies
    ``fullDocument`` and ``updateDescription``, each kept only when it normalizes to
    a dict, and the ``raw_json`` snapshot.

    Args:
        change: One change-stream document as produced by the driver.

    Returns:
        The assembled event; ``lag_ms`` is measured against the current clock.

    Raises:
        TypeError: If an envelope field is missing or malformed, or the normalized
            document is not a mapping.
        ValueError: If the document nests deeper than ``_MAX_BSON_DEPTH``.
    """
    token = _normalize_resume_token(change.get("_id"))
    event_key = _resume_token_event_id(token)
    op_type = _required_str(change.get("operationType"), "operationType")
    # NOTE(review): MongoDB documents ``invalidate`` events without an ``ns`` field;
    # ``_build_namespace`` rejects those with TypeError. Confirm that is intended.
    ns = _build_namespace(change.get("ns"))
    cluster = _build_cluster_time(change.get("clusterTime"))
    wall_ms = _build_wall_time_ms(change.get("wallTime"), cluster)
    lag = _event_lag_ms(wall_ms, cluster)

    payload = normalize_bson_value(change)
    if not isinstance(payload, dict):
        raise TypeError("Normalized Mongo change event must be a mapping.")
    full_doc = payload.get("fullDocument")
    update_desc = payload.get("updateDescription")
    doc_id = _document_id(change.get("documentKey"))
    # The JSON snapshot is encoded last, once every validation step has passed.
    return MongoCDCEvent(
        event_id=event_key,
        operation_type=op_type,
        namespace=ns,
        resume_token=token,
        document_id=doc_id,
        cluster_time=cluster,
        wall_time_ms=wall_ms,
        lag_ms=lag,
        full_document=full_doc if isinstance(full_doc, dict) else None,
        update_description=update_desc if isinstance(update_desc, dict) else None,
        raw_json=msgspec.json.encode(payload).decode("utf-8"),
    )
def _message_key(document_id: MongoObjectId | str | None) -> str | None:
    """Derive the message partition key from a normalized document identifier.

    A ``MongoObjectId`` contributes its ``id`` string. Any other value, such as a
    plain string or ``None``, is passed through unchanged.
    """
    if isinstance(document_id, MongoObjectId):
        return document_id.id
    return document_id
[docs]
def build_mongo_cdc_message(change: Mapping[str, object]) -> Message[MongoCDCEvent]:
    """Build a transport-neutral Loom message for one Mongo change event.

    The event id (derived from the resume token) is used as both ``message_id``
    and ``trace_id``. ``produced_at_ms`` is the event time: wall time if known,
    otherwise cluster time, otherwise ``None``. The partition ``key`` comes from
    the document id.

    Args:
        change: One raw change-stream document.

    Returns:
        A message whose payload is the normalized ``MongoCDCEvent``.

    Raises:
        TypeError, ValueError: Propagated from ``build_mongo_cdc_event``.
    """
    event = build_mongo_cdc_event(change)
    meta = MessageMeta(
        message_id=event.event_id,
        trace_id=event.event_id,
        # Reuse the shared wall-time -> cluster-time fallback instead of duplicating it.
        produced_at_ms=_event_time_ms(event.wall_time_ms, event.cluster_time),
        message_type=_MONGO_MESSAGE_TYPE,
        key=_message_key(event.document_id),
    )
    return Message(payload=event, meta=meta)
def _normalize_resume_token(value: object) -> dict[str, object]:
if not isinstance(value, Mapping):
raise TypeError("Mongo change event '_id' must be a mapping resume token.")
normalized = normalize_bson_value(value)
if not isinstance(normalized, dict):
raise TypeError("Normalized Mongo resume token must remain a mapping.")
return normalized
def _resume_token_event_id(resume_token: Mapping[str, object]) -> str:
token_data = resume_token.get("_data")
if isinstance(token_data, str) and token_data:
return token_data
return msgspec.json.encode(dict(resume_token)).decode("utf-8")
def _required_str(value: object, field: str) -> str:
if not isinstance(value, str) or not value:
raise TypeError(f"Mongo change event field '{field}' must be a non-empty string.")
return value
def _build_namespace(value: object) -> MongoCDCNamespace:
if not isinstance(value, Mapping):
raise TypeError("Mongo change event field 'ns' must be a mapping.")
db = _required_str(value.get("db"), "ns.db")
coll = value.get("coll")
if coll is not None and not isinstance(coll, str):
raise TypeError("Mongo change event field 'ns.coll' must be a string when present.")
return MongoCDCNamespace(db=db, coll=coll)
def _build_cluster_time(value: object) -> MongoBsonTimestamp | None:
if value is None:
return None
if isinstance(value, MongoBsonTimestamp):
return value
if type(value).__name__ == "Timestamp":
normalized = _normalize_timestamp_mapping(value)
return MongoBsonTimestamp(
seconds=_required_int(normalized.get("seconds"), "clusterTime.seconds"),
increment=_required_int(normalized.get("increment"), "clusterTime.increment"),
)
raise TypeError("Mongo change event field 'clusterTime' must be a BSON Timestamp.")
def _build_wall_time_ms(
value: object,
cluster_time: MongoBsonTimestamp | None,
) -> int | None:
if value is None:
return cluster_time.seconds * 1000 if cluster_time is not None else None
if isinstance(value, datetime):
return _datetime_to_epoch_ms(value)
if type(value).__name__ == "DatetimeMS":
return _normalize_datetime_ms(value)
normalized = normalize_bson_value(value)
if isinstance(normalized, int):
return normalized
raise TypeError("Mongo change event field 'wallTime' must be a datetime when present.")
def _normalize_optional_mapping(value: object) -> dict[str, object] | None:
if value is None:
return None
normalized = normalize_bson_value(value)
if not isinstance(normalized, dict):
raise TypeError("Normalized Mongo nested document must remain a mapping.")
return normalized
def _document_id(value: object) -> MongoObjectId | str | None:
if value is None:
return None
if not isinstance(value, Mapping):
raise TypeError("Mongo change event field 'documentKey' must be a mapping when present.")
raw = normalize_bson_value(value.get("_id"))
return raw if isinstance(raw, (MongoObjectId, str)) else None
def _normalize_timestamp_mapping(value: object) -> dict[str, object]:
seconds = getattr(value, "time", None)
increment = getattr(value, "inc", None)
if not isinstance(seconds, int) or not isinstance(increment, int):
raise TypeError("BSON Timestamp must expose integer 'time' and 'inc' attributes.")
return {"seconds": seconds, "increment": increment}
def _normalize_decimal128(value: object) -> str:
to_decimal = getattr(value, "to_decimal", None)
if callable(to_decimal):
return str(to_decimal())
return str(value)
def _normalize_binary(value: object) -> str:
    """Encode a binary payload as standard, padded base64 ASCII text."""
    payload: bytes = bytes(cast("_SupportsBytes", value))
    encoded = base64.b64encode(payload)
    return encoded.decode("ascii")
def _normalize_dbref(value: object) -> object:
collection = getattr(value, "collection", None)
database = getattr(value, "database", None)
identifier = getattr(value, "id", None)
if not isinstance(collection, str):
raise TypeError("BSON DBRef must expose a string 'collection' attribute.")
if database is not None and not isinstance(database, str):
raise TypeError("BSON DBRef attribute 'database' must be a string when present.")
if identifier is None:
raise TypeError("BSON DBRef must have a non-None 'id' attribute.")
normalized_id = normalize_bson_value(identifier)
id_str = normalized_id.id if isinstance(normalized_id, MongoObjectId) else str(normalized_id)
return MongoDBRef(id=id_str, collection=collection, database=database)
def _required_int(value: object, field: str) -> int:
if not isinstance(value, int):
raise TypeError(f"Mongo change event field '{field}' must be an integer.")
return value
def _datetime_to_epoch_ms(value: datetime) -> int:
if value.tzinfo is None:
value = value.replace(tzinfo=UTC)
return int(value.astimezone(UTC).timestamp() * 1000)
def _current_time_ms() -> int:
    """Return the current UTC wall-clock time as epoch milliseconds."""
    now = datetime.now(UTC)
    return _datetime_to_epoch_ms(now)
def _event_time_ms(
wall_time_ms: int | None,
cluster_time: MongoBsonTimestamp | None,
) -> int | None:
if wall_time_ms is not None:
return wall_time_ms
if cluster_time is not None:
return cluster_time.seconds * 1000
return None
def _event_lag_ms(
    wall_time_ms: int | None,
    cluster_time: MongoBsonTimestamp | None,
) -> int | None:
    """Return how far the event trails the current clock, in milliseconds.

    The result is clamped at zero. Returns ``None`` when the event carries
    neither a wall time nor a cluster time.
    """
    event_time_ms = _event_time_ms(wall_time_ms, cluster_time)
    return None if event_time_ms is None else max(0, _current_time_ms() - event_time_ms)
def _normalize_objectid(value: object) -> object:
    """Convert a BSON ObjectId-like value into a ``MongoObjectId``.

    ``id`` is ``str(value)``. ``created_at_ms`` comes from the value's
    ``generation_time`` when that is a ``datetime``, and is ``None`` otherwise.
    """
    generated = getattr(value, "generation_time", None)
    created_at_ms: int | None = None
    if isinstance(generated, datetime):
        created_at_ms = _datetime_to_epoch_ms(generated)
    return MongoObjectId(id=str(value), created_at_ms=created_at_ms)
def _identity(value: object) -> object:
return value
def _normalize_datetime_ms(value: object) -> int:
# DatetimeMS.__int__() returns milliseconds since Unix epoch; safe for out-of-range years.
return int(value) # type: ignore[call-overload, no-any-return]
# Dispatch table used by ``normalize_bson_value``: exact class name -> normalizer.
# Matching on ``type(value).__name__`` avoids importing ``bson`` in this module;
# subclasses of these types therefore do not match.
_BSON_NORMALIZERS: dict[str, Callable[[object], object]] = {
    # Identifiers and references.
    "ObjectId": _normalize_objectid,
    # Replication timestamps -> {"seconds", "increment"}.
    "Timestamp": _normalize_timestamp_mapping,
    # High-precision decimals -> decimal string.
    "Decimal128": _normalize_decimal128,
    # Binary payloads -> base64 text.
    "Binary": _normalize_binary,
    "DBRef": _normalize_dbref,
    # Millisecond datetimes -> epoch milliseconds.
    "DatetimeMS": _normalize_datetime_ms,
}

__all__ = [
    "build_mongo_cdc_event",
    "build_mongo_cdc_message",
    "normalize_bson_value",
]