"""OpenAPI schema helpers for FastAPI route binding.
Builds request/response schema fragments from Loom UseCase contracts while
keeping runtime execution transport-agnostic.
"""
from __future__ import annotations
from types import NoneType, UnionType
from typing import Any, Union, get_args, get_origin, get_type_hints
import msgspec
from loom.core.command.introspection import (
get_calculated_fields,
get_internal_fields,
get_patch_fields,
)
from loom.core.engine.plan import ExecutionPlan
from loom.core.repository.abc.query import PaginationMode, QuerySpec
from loom.rest.compiler import CompiledRoute
from loom.rest.constants import QueryParam
JsonSchema = dict[str, Any]
_PARAM_IN_QUERY = "query"
_PARAM_NAME = "name"
_PARAM_IN = "in"
_PARAM_REQUIRED = "required"
_PARAM_SCHEMA = "schema"
_PARAM_DESCRIPTION = "description"
# Re-export for backwards compatibility — consumers may import from this module.
QUERY_PARAM_PAGE = QueryParam.PAGE
QUERY_PARAM_LIMIT = QueryParam.LIMIT
QUERY_PARAM_PAGINATION = QueryParam.PAGINATION
QUERY_PARAM_AFTER = QueryParam.AFTER
QUERY_PARAM_CURSOR = QueryParam.CURSOR
QUERY_PARAM_SORT = QueryParam.SORT
QUERY_PARAM_DIRECTION = QueryParam.DIRECTION
QUERY_PARAM_PROFILE = QueryParam.PROFILE
QUERY_SPEC_PARAMETER_NAMES: tuple[str, ...] = (
QueryParam.PAGE,
QueryParam.LIMIT,
QueryParam.PAGINATION,
QueryParam.AFTER,
QueryParam.CURSOR,
QueryParam.SORT,
QueryParam.DIRECTION,
)
_SCHEMA_DEFS_KEY = "$defs"
_SCHEMA_REF_KEY = "$ref"
_LOCAL_DEFS_REF_PREFIX = "#/$defs/"
_COMPONENT_REF_PREFIX = "#/components/schemas/"
[docs]
def build_request_body_schema(
compiled_route: CompiledRoute,
component_registry: dict[str, JsonSchema] | None = None,
) -> JsonSchema | None:
"""Return OpenAPI ``requestBody`` schema for the route, if it has Input().
Args:
compiled_route: Fully resolved route from ``RestInterfaceCompiler``.
component_registry: Optional mutable dict that accumulates shared
component schemas. When provided, nested ``$defs`` are extracted
as ``#/components/schemas/`` entries.
Returns:
OpenAPI ``requestBody`` fragment, or ``None`` when the route has no
``Input()`` binding.
"""
plan = _get_execution_plan(compiled_route)
if plan is None or plan.input_binding is None:
return None
request_schema = _command_request_schema(plan.input_binding.command_type, component_registry)
if request_schema is None:
return None
return {
"required": True,
"content": {
"application/json": {
"schema": request_schema,
}
},
}
[docs]
def build_success_response_schema(
compiled_route: CompiledRoute,
component_registry: dict[str, JsonSchema] | None = None,
) -> JsonSchema | None:
"""Return OpenAPI response entry for the route success status, if resolvable.
Args:
compiled_route: Fully resolved route from ``RestInterfaceCompiler``.
component_registry: Optional mutable dict that accumulates shared
component schemas. When provided, nested ``$defs`` are extracted
as ``#/components/schemas/`` entries.
Returns:
OpenAPI response schema fragment, or ``None`` when the return type
cannot be resolved to a JSON Schema.
"""
response_schema = _use_case_response_schema(compiled_route.route.use_case, component_registry)
if response_schema is None:
return None
return {
"content": {
"application/json": {
"schema": response_schema,
}
}
}
[docs]
def build_query_parameters_schema(compiled_route: CompiledRoute) -> list[JsonSchema]:
"""Return OpenAPI query parameters inferred from the use-case contract."""
use_case_type = compiled_route.route.use_case
hints = get_type_hints(use_case_type.execute)
params: list[JsonSchema] = []
if _has_query_spec_parameter(hints):
params.extend(_query_spec_openapi_parameters(compiled_route.effective_pagination_mode))
if _has_profile_parameter(hints) and compiled_route.effective_expose_profile:
profile_param: JsonSchema = {
_PARAM_NAME: QUERY_PARAM_PROFILE,
_PARAM_IN: _PARAM_IN_QUERY,
_PARAM_REQUIRED: False,
_PARAM_SCHEMA: {"type": "string"},
_PARAM_DESCRIPTION: "Projection profile used by repository mappings.",
}
allowed = compiled_route.effective_allowed_profiles
if allowed:
profile_param[_PARAM_SCHEMA]["enum"] = list(allowed)
params.append(profile_param)
return params
def _get_execution_plan(compiled_route: CompiledRoute) -> ExecutionPlan | None:
plan = getattr(compiled_route.route.use_case, "__execution_plan__", None)
if isinstance(plan, ExecutionPlan):
return plan
return None
def _has_profile_parameter(type_hints: dict[str, Any]) -> bool:
return "profile" in type_hints
def _has_query_spec_parameter(type_hints: dict[str, Any]) -> bool:
for annotation in type_hints.values():
if annotation is QuerySpec:
return True
origin = get_origin(annotation)
args = get_args(annotation)
if origin in (UnionType, Union) and QuerySpec in args:
return True
return False
def _query_spec_openapi_parameters(default_mode: PaginationMode) -> list[JsonSchema]:
mode_values = [item.value for item in PaginationMode]
return [
{
_PARAM_NAME: QUERY_PARAM_PAGE,
_PARAM_IN: _PARAM_IN_QUERY,
_PARAM_REQUIRED: False,
_PARAM_SCHEMA: {"type": "integer", "minimum": 1, "default": 1},
_PARAM_DESCRIPTION: "Offset page number.",
},
{
_PARAM_NAME: QUERY_PARAM_LIMIT,
_PARAM_IN: _PARAM_IN_QUERY,
_PARAM_REQUIRED: False,
_PARAM_SCHEMA: {"type": "integer", "minimum": 1, "maximum": 1000, "default": 50},
_PARAM_DESCRIPTION: "Maximum rows per page.",
},
{
_PARAM_NAME: QUERY_PARAM_PAGINATION,
_PARAM_IN: _PARAM_IN_QUERY,
_PARAM_REQUIRED: False,
_PARAM_SCHEMA: {"type": "string", "enum": mode_values, "default": default_mode.value},
_PARAM_DESCRIPTION: "Pagination strategy.",
},
{
_PARAM_NAME: QUERY_PARAM_AFTER,
_PARAM_IN: _PARAM_IN_QUERY,
_PARAM_REQUIRED: False,
_PARAM_SCHEMA: {"type": "string"},
_PARAM_DESCRIPTION: "Cursor token for cursor pagination.",
},
{
_PARAM_NAME: QUERY_PARAM_CURSOR,
_PARAM_IN: _PARAM_IN_QUERY,
_PARAM_REQUIRED: False,
_PARAM_SCHEMA: {"type": "string"},
_PARAM_DESCRIPTION: "Alias for 'after'.",
},
{
_PARAM_NAME: QUERY_PARAM_SORT,
_PARAM_IN: _PARAM_IN_QUERY,
_PARAM_REQUIRED: False,
_PARAM_SCHEMA: {"type": "string"},
_PARAM_DESCRIPTION: "Field used for ordering.",
},
{
_PARAM_NAME: QUERY_PARAM_DIRECTION,
_PARAM_IN: _PARAM_IN_QUERY,
_PARAM_REQUIRED: False,
_PARAM_SCHEMA: {"type": "string", "enum": ["ASC", "DESC"], "default": "ASC"},
_PARAM_DESCRIPTION: "Sort direction.",
},
]
def _use_case_response_schema(
use_case_type: type[Any],
component_registry: dict[str, JsonSchema] | None = None,
) -> JsonSchema | None:
hints = get_type_hints(use_case_type.execute)
return_type = hints.get("return")
if return_type is None or return_type is Any:
return None
return _safe_schema(return_type, component_registry)
def _command_request_schema(
command_type: type[Any],
component_registry: dict[str, JsonSchema] | None = None,
) -> JsonSchema | None:
if not isinstance(command_type, type):
return None
if issubclass(command_type, msgspec.Struct):
public_struct = _build_public_command_struct(command_type)
return _safe_schema(public_struct, component_registry)
# Fallback for user-defined plain types (dataclass/pydantic/typing).
return _safe_schema(command_type, component_registry)
def _build_public_command_struct(command_type: type[Any]) -> type[Any]:
struct_fields = {field.name: field for field in msgspec.structs.fields(command_type)}
excluded = set(get_internal_fields(command_type)) | set(get_calculated_fields(command_type))
patch_fields = set(get_patch_fields(command_type))
definitions: list[tuple[Any, ...]] = []
for name, sf in struct_fields.items():
if name in excluded:
continue
annotation = _without_unset_type(sf.type)
if name in patch_fields:
definitions.append((name, _with_optional_none(annotation), None))
continue
if sf.default is msgspec.NODEFAULT:
definitions.append((name, annotation))
continue
definitions.append((name, annotation, sf.default))
use_camel = any(field.encode_name != field.name for field in struct_fields.values())
return msgspec.defstruct(
f"{command_type.__name__}Request",
definitions,
kw_only=True,
rename="camel" if use_camel else None,
)
def _safe_msgspec_schema(annotation: Any) -> JsonSchema | None:
try:
return msgspec.json.schema(annotation)
except (TypeError, ValueError):
return None
def _safe_pydantic_schema(annotation: Any) -> JsonSchema | None:
try:
from pydantic import TypeAdapter # lazy — only when pydantic is installed
adapter = TypeAdapter(annotation)
return adapter.json_schema()
except Exception:
return None
def _rewrite_to_component_refs(node: Any) -> Any:
"""Rewrite ``#/$defs/X`` references to ``#/components/schemas/X`` recursively.
Strips ``$defs`` keys and rewrites every ``$ref`` value that points to a
local definition so the resulting fragment is valid at the OpenAPI document
root where ``#/components/schemas/`` is resolvable.
"""
if isinstance(node, list):
return [_rewrite_to_component_refs(item) for item in node]
if not isinstance(node, dict):
return node
result: dict[str, Any] = {}
for k, v in node.items():
if k == _SCHEMA_DEFS_KEY:
continue
if k == _SCHEMA_REF_KEY and isinstance(v, str) and v.startswith(_LOCAL_DEFS_REF_PREFIX):
result[k] = _COMPONENT_REF_PREFIX + v.removeprefix(_LOCAL_DEFS_REF_PREFIX)
else:
result[k] = _rewrite_to_component_refs(v)
return result
def _safe_schema(
annotation: Any,
component_registry: dict[str, JsonSchema] | None = None,
) -> JsonSchema | None:
origin = get_origin(annotation)
if origin in (UnionType, Union):
union_members = [
_safe_schema(member, component_registry) for member in get_args(annotation)
]
if any(member is None for member in union_members):
return None
members = [member for member in union_members if member is not None]
if len(members) == 1:
return members[0]
return {"anyOf": members}
raw = _safe_msgspec_schema(annotation) or _safe_pydantic_schema(annotation)
if raw is None:
return None
if component_registry is None:
return raw
defs = raw.get(_SCHEMA_DEFS_KEY)
if isinstance(defs, dict):
for name, def_schema in defs.items():
component_registry[name] = _rewrite_to_component_refs(def_schema)
rewritten: JsonSchema = _rewrite_to_component_refs(raw)
return rewritten
def _without_unset_type(annotation: Any) -> Any:
if annotation is msgspec.UnsetType:
return Any
origin = get_origin(annotation)
if origin is None:
return annotation
if origin in (UnionType, Union):
args = tuple(arg for arg in get_args(annotation) if arg is not msgspec.UnsetType)
if not args:
return Any
result = args[0]
for arg in args[1:]:
result = result | arg
return result
return annotation
def _with_optional_none(annotation: Any) -> Any:
origin = get_origin(annotation)
if origin in (UnionType, Union) and NoneType in get_args(annotation):
return annotation
if annotation is NoneType:
return annotation
return annotation | None