# Source code for ccsds_ndm.model_quantity

# CCSDS-NDM: CCSDS Navigation Data Messages Read/Write Library
#
# Copyright (C) CCSDS-NDM Egemen Imre
#
# Licensed under GNU GPL v3.0. See LICENSE for more info.
"""
Validation and quantity support for xsdata model classes.

Patches ``__setattr__`` on container dataclasses so that:

- Already-wrapped values (e.g. ``LengthType``) pass through unchanged.
- pint / astropy Quantity objects are converted to the correct wrapper type
  with unit conversion and dimensional validation.
- Plain numbers, strings, and other invalid types raise ``TypeError``.
- ``None`` is allowed on optional fields.

Also patches a ``.q()`` method on wrapper types (dataclasses with ``value``
and ``units`` fields) that returns a pint or astropy Quantity, controlled by
a global mode flag.

Patching runs automatically when this module is imported.
"""

from __future__ import annotations

import dataclasses
import inspect
import warnings
from enum import Enum
from typing import get_type_hints

from ccsds_ndm.model_validate import (
    _patch_all_models,
    _unwrap_optional,
)

# ---------------------------------------------------------------------------
# Global quantity mode
# ---------------------------------------------------------------------------


class QuantityMode(Enum):
    """Quantity backend selection.

    Selects which library ``.q()`` on wrapper types uses to build the
    returned Quantity.
    """

    # Member values are the lowercase backend names.
    PINT = "pint"
    ASTROPY = "astropy"
# Default to astropy; can be switched at runtime via set_quantity_mode()
_quantity_mode: QuantityMode = QuantityMode.ASTROPY
def set_quantity_mode(mode: object) -> None:
    """Set the global quantity mode for ``.q()`` on wrapper types.

    Parameters
    ----------
    mode:
        A ``QuantityMode`` member. Typed as ``object`` because the value is
        validated at runtime rather than trusted.

    Raises
    ------
    TypeError
        If *mode* is not a ``QuantityMode`` instance (e.g. a plain string).
    """
    global _quantity_mode
    if not isinstance(mode, QuantityMode):
        raise TypeError(
            f"Expected QuantityMode, got {type(mode).__name__}. "
            f"Use QuantityMode.PINT or QuantityMode.ASTROPY."
        )
    _quantity_mode = mode
def get_quantity_mode() -> QuantityMode:
    """Return the current global quantity mode."""
    return _quantity_mode
# ---------------------------------------------------------------------------
# Global auto-convert flag
# ---------------------------------------------------------------------------

# When True, dimensionally compatible Quantities are automatically converted
# to the field's default CCSDS unit instead of raising TypeError.
_auto_convert: bool = False
def set_auto_convert(enabled: object) -> None:
    """Enable or disable automatic unit conversion on Quantity assignment.

    Parameters
    ----------
    enabled:
        ``True`` to enable conversion, ``False`` to disable it. Typed as
        ``object`` because the value is validated at runtime.

    Raises
    ------
    TypeError
        If *enabled* is not a ``bool`` (integers such as ``1`` are rejected).
    """
    global _auto_convert
    if not isinstance(enabled, bool):
        raise TypeError(f"Expected bool, got {type(enabled).__name__}.")
    _auto_convert = enabled
def get_auto_convert() -> bool:
    """Return whether automatic unit conversion is enabled."""
    return _auto_convert
# ---------------------------------------------------------------------------
# CCSDS <-> pint / astropy unit string mappings
# ---------------------------------------------------------------------------

# pint does not recognise "rev/day" or "#/yr" natively; map them to
# pint's canonical spelling.  Keys cover both lower- and upper-case
# variants that appear in NDM XML/KVN files.
_CCSDS_TO_PINT: dict[str, str] = {
    "rev/day": "revolution/day",
    "rev/day**2": "revolution/day**2",
    "rev/day**3": "revolution/day**3",
    "REV/DAY": "revolution/day",
    "REV/DAY**2": "revolution/day**2",
    "REV/DAY**3": "revolution/day**3",
    "#/yr": "1/year",
}

# astropy uses "cycle" for revolutions and "d" for day.
# "1/ER" (inverse Earth-radius, BSTAR unit) maps to astropy's built-in
# earthRad constant.  "SFU" (solar flux unit) has no astropy equivalent
# and is listed in _UNSUPPORTED_UNITS instead.
_CCSDS_TO_ASTROPY: dict[str, str] = {
    "rev/day": "cycle/d",
    "rev/day**2": "cycle/d**2",
    "rev/day**3": "cycle/d**3",
    "REV/DAY": "cycle/d",
    "REV/DAY**2": "cycle/d**2",
    "REV/DAY**3": "cycle/d**3",
    "1/ER": "1/earthRad",
    "#/yr": "1/yr",
}

# CCSDS unit strings that have no equivalent in either pint or astropy.
# Fields with these units return a dimensionless Quantity with a warning.
_UNSUPPORTED_UNITS: set[str] = {"SFU"}


def _ccsds_to_pint(unit_str: str) -> str:
    """Convert a CCSDS unit string to its pint equivalent.

    Strings without an entry in ``_CCSDS_TO_PINT`` are returned unchanged.
    """
    return _CCSDS_TO_PINT.get(unit_str, unit_str)


def _ccsds_to_astropy(unit_str: str) -> str:
    """Convert a CCSDS unit string to its astropy equivalent.

    Strings without an entry in ``_CCSDS_TO_ASTROPY`` are returned unchanged.
    """
    return _CCSDS_TO_ASTROPY.get(unit_str, unit_str)


# ---------------------------------------------------------------------------
# Cached pint UnitRegistry
# ---------------------------------------------------------------------------

_pint_ureg = (
    None  # lazily initialised on first use to avoid importing pint at load time
)


def _get_pint_ureg():
    """Return a cached pint ``UnitRegistry`` with custom CCSDS units."""
    global _pint_ureg
    if _pint_ureg is None:
        import pint

        registry = pint.UnitRegistry()
        # ER = Earth Radius as used in SGP4 (Hoots & Roehrich 1980).
        # Intentionally differs from WGS84 (6378.137 km) to keep BSTAR values
        # consistent with the TLE convention that defines the unit.
        registry.define("earth_radius = 6378.135 km = ER = R_earth")
        # Publish the registry only once it is fully set up, so a failing
        # ``define`` cannot leave a half-initialised registry in the cache.
        _pint_ureg = registry
    return _pint_ureg


# ---------------------------------------------------------------------------
# Type detection helpers
# ---------------------------------------------------------------------------


def _has_quantity_base(value, package: str) -> bool:
    """Return True if *value*'s type is, or derives from, ``<package>…Quantity``.

    The check walks the MRO by class name and module prefix so that neither
    pint nor astropy has to be imported, and so that Quantity subclasses are
    recognised as well as the base class itself.
    """
    return any(
        klass.__name__ == "Quantity"
        and str(getattr(klass, "__module__", "")).startswith(package)
        for klass in type(value).__mro__
    )


def _is_pint_quantity(value) -> bool:
    """Check if *value* is a pint Quantity without importing pint."""
    return _has_quantity_base(value, "pint")


def _is_astropy_quantity(value) -> bool:
    """Check if *value* is an astropy Quantity without importing astropy."""
    return _has_quantity_base(value, "astropy")


def _get_units_enum(wrapper_cls) -> type[Enum]:
    """Extract the units Enum class from a wrapper type's type hints."""
    hints = get_type_hints(wrapper_cls)
    return _unwrap_optional(hints["units"])


# ---------------------------------------------------------------------------
# Unit matching helpers
# ---------------------------------------------------------------------------


def _accepted_unit_strings(units_enum: type[Enum]) -> list[str]:
    """Return CCSDS unit strings for all supported (non-_UNSUPPORTED) members."""
    return [m.value for m in units_enum if m.value not in _UNSUPPORTED_UNITS]


def _match_pint_unit(value, units_enum: type[Enum]) -> tuple[object, bool]:
    """Return ``(matched_member_or_None, dimensions_ok)`` for a pint Quantity.

    Re-parses the incoming unit string through our registry to handle
    cross-registry comparisons safely.

    ``dimensions_ok`` is True when at least one supported member of
    *units_enum* has the same dimensionality as *value*, even if no member
    has exactly the same unit.
    """
    import pint

    u = _get_pint_ureg()
    try:
        # Re-parse the incoming unit through our registry; this normalises
        # quantities that may come from a different pint UnitRegistry instance,
        # making dimensionality and unit comparisons safe across registries.
        incoming = u.Quantity(f"1 {value.units}")
    except pint.errors.PintError:
        # Unrecognised unit string — cannot match anything
        return None, False

    dims_ok = False
    for member in units_enum:
        ccsds_str = member.value
        if ccsds_str in _UNSUPPORTED_UNITS:
            # Skip units that have no pint equivalent (e.g. non-SI CCSDS units)
            continue
        try:
            target = u.Quantity(f"1 {_ccsds_to_pint(ccsds_str)}")
        except pint.errors.PintError:
            # Mapping produced a unit string pint still can't parse — skip
            continue
        if incoming.dimensionality == target.dimensionality:
            dims_ok = True  # at least one member shares the same dimensions
            if incoming.units == target.units:
                return member, True  # exact unit match found
    return None, dims_ok


def _match_astropy_unit(value, units_enum: type[Enum]) -> tuple[object, bool]:
    """Return ``(matched_member_or_None, dimensions_ok)`` for an astropy Quantity.

    ``dimensions_ok`` is True when at least one supported member of
    *units_enum* has the same physical type as *value*.
    """
    from astropy import units as astropy_u

    dims_ok = False
    for member in units_enum:
        ccsds_str = member.value
        if ccsds_str in _UNSUPPORTED_UNITS:
            # Skip units that have no astropy equivalent (e.g. non-SI CCSDS units)
            continue
        try:
            target = astropy_u.Unit(_ccsds_to_astropy(ccsds_str))
        except (astropy_u.UnitsError, ValueError):
            # Mapping produced a unit string astropy still can't parse — skip
            continue
        if value.unit.physical_type == target.physical_type:
            dims_ok = True  # at least one member shares the same physical type
            if value.unit == target:
                return member, True  # exact unit match found
    return None, dims_ok


# ---------------------------------------------------------------------------
# Auto-convert helpers
# ---------------------------------------------------------------------------


def _default_unit_member(units_enum: type[Enum]) -> Enum | None:
    """Return the first supported enum member (the NDM default unit).

    Returns ``None`` when every member is listed in ``_UNSUPPORTED_UNITS``.
    """
    return next(
        (m for m in units_enum if m.value not in _UNSUPPORTED_UNITS),
        None,
    )


def _convert_pint(value, target_member) -> float:
    """Convert a pint Quantity to *target_member*'s unit; return magnitude."""
    u = _get_pint_ureg()
    target_unit = u.Unit(_ccsds_to_pint(target_member.value))
    # Rebuild the quantity in our registry so the conversion is safe across
    # registries.  The magnitude is passed through as a number rather than
    # formatted into a string and re-parsed, so values whose text form the
    # expression parser may not accept are still handled.
    rebuilt = u.Quantity(value.magnitude, str(value.units))
    return float(rebuilt.to(target_unit).magnitude)


def _convert_astropy(value, target_member) -> float:
    """Convert an astropy Quantity to *target_member*'s unit; return value."""
    from astropy import units as astropy_u

    target_unit = astropy_u.Unit(_ccsds_to_astropy(target_member.value))
    return float(value.to(target_unit).value)


# ---------------------------------------------------------------------------
# Field map construction
# ---------------------------------------------------------------------------

_FieldInfo = tuple  # (wrapper_cls, units_enum_cls)


def _build_field_map(cls, wrapper_types: set[type]) -> dict[str, _FieldInfo]:
    """Build ``{field_name: (wrapper_cls, units_enum_cls)}`` for *cls*.

    Only fields whose (Optional-unwrapped) type is a member of
    *wrapper_types* are included.  An empty dict is returned when the type
    hints of *cls* cannot be resolved.
    """
    try:
        hints = get_type_hints(cls)
    except Exception:
        # Best effort: a class whose annotations cannot be resolved is left
        # unpatched rather than breaking import of the whole package.
        return {}

    field_map: dict[str, _FieldInfo] = {}
    for f in dataclasses.fields(cls):
        raw_type = hints.get(f.name)
        if raw_type is None:
            continue
        resolved = _unwrap_optional(raw_type)
        if not isinstance(resolved, type):
            continue
        # Only include fields whose type is a known wrapper (value+units dataclass)
        if resolved not in wrapper_types:
            continue
        field_map[f.name] = (resolved, _get_units_enum(resolved))
    return field_map


# ---------------------------------------------------------------------------
# __setattr__ factory for container types
# ---------------------------------------------------------------------------


def _resolve_quantity(
    value,
    name: str,
    owner_type: str,
    wrapper_cls,
    units_enum,
    matched,
    dims_ok: bool,
    unit_attr: str,
    mag_attr: str,
    convert_fn,
):
    """Convert a pint/astropy Quantity to the correct wrapper, or raise TypeError.

    Parameters
    ----------
    value:
        The incoming pint or astropy Quantity.
    name:
        Name of the field being assigned (used in error messages).
    owner_type:
        Name of the class owning the field (used in error messages).
    wrapper_cls:
        Wrapper dataclass to instantiate with ``value=`` and ``units=``.
    units_enum:
        Units Enum class of *wrapper_cls*.
    matched:
        Enum member whose unit exactly equals the unit of *value*, or ``None``.
    dims_ok:
        Whether *value* is dimensionally compatible with the field.
    unit_attr:
        Attribute name for the unit on *value* (``"units"`` for pint,
        ``"unit"`` for astropy).
    mag_attr:
        Attribute name for the magnitude on *value* (``"magnitude"`` for
        pint, ``"value"`` for astropy).
    convert_fn:
        Library-specific conversion helper (``_convert_pint`` or
        ``_convert_astropy``).

    Raises
    ------
    TypeError
        If the unit is not accepted and auto-conversion is disabled or
        fails, or if the dimensions are incompatible.
    """
    if matched is not None:
        return wrapper_cls(value=float(getattr(value, mag_attr)), units=matched)

    unit_str = getattr(value, unit_attr)
    if not dims_ok:
        raise TypeError(
            f"Cannot assign {unit_str} quantity to field '{name}' on {owner_type} "
            f"which expects {wrapper_cls.__name__}. Dimensions are incompatible."
        )

    accepted = _accepted_unit_strings(units_enum)
    default = _default_unit_member(units_enum)
    if default is not None and _auto_convert:
        try:
            return wrapper_cls(value=convert_fn(value, default), units=default)
        except Exception as exc:
            # Keep the TypeError contract of attribute assignment, but chain
            # the underlying failure so it stays visible in the traceback.
            raise TypeError(
                f"Unit '{unit_str}' could not be auto-converted "
                f"for field '{name}' on {owner_type}. "
                f"Accepted units for {wrapper_cls.__name__}: {accepted}."
            ) from exc
    raise TypeError(
        f"Unit '{unit_str}' is not accepted for field '{name}' on {owner_type}. "
        f"Accepted units for {wrapper_cls.__name__}: {accepted}. "
        f"Convert your Quantity first."
    )


def _make_setattr(field_map: dict[str, _FieldInfo]):
    """Return a ``__setattr__`` that validates and wraps Quantities.

    By default the unit of the incoming Quantity must exactly match one of
    the accepted CCSDS unit strings for the field.  When ``_auto_convert``
    is enabled, dimensionally compatible Quantities are silently converted
    to the field's default CCSDS unit instead of raising ``TypeError``.
    """

    def __setattr__(self, name: str, value):
        info = field_map.get(name)
        # Fields outside the map, and None assignments, pass straight through.
        if info is not None and value is not None:
            wrapper_cls, units_enum = info
            if isinstance(value, wrapper_cls):
                pass  # already the correct wrapper type
            elif _is_pint_quantity(value):
                matched, dims_ok = _match_pint_unit(value, units_enum)
                value = _resolve_quantity(
                    value,
                    name,
                    type(self).__name__,
                    wrapper_cls,
                    units_enum,
                    matched,
                    dims_ok,
                    "units",
                    "magnitude",
                    _convert_pint,
                )
            elif _is_astropy_quantity(value):
                matched, dims_ok = _match_astropy_unit(value, units_enum)
                value = _resolve_quantity(
                    value,
                    name,
                    type(self).__name__,
                    wrapper_cls,
                    units_enum,
                    matched,
                    dims_ok,
                    "unit",
                    "value",
                    _convert_astropy,
                )
            else:
                # Not a wrapper, not a Quantity — reject with a helpful message
                raise TypeError(
                    f"Field '{name}' on {type(self).__name__} expects "
                    f"{wrapper_cls.__name__}, got "
                    f"{type(value).__name__}({value!r}). "
                    f"Use e.g. {wrapper_cls.__name__}(value=..., units=...) "
                    f"or a pint/astropy Quantity."
                )
        object.__setattr__(self, name, value)

    return __setattr__


# ---------------------------------------------------------------------------
# .q() method for wrapper types
# ---------------------------------------------------------------------------


def _make_wrapper_q():
    """Return a ``.q()`` method to be patched onto wrapper types."""

    def q(self):
        """Return this value as a pint or astropy Quantity."""
        mode = get_quantity_mode()
        # Extract the raw CCSDS unit string (e.g. "km", "1/ER") from the enum
        units = getattr(self, "units", None)
        unit_str = units.value if units is not None else None

        unsupported = unit_str in _UNSUPPORTED_UNITS
        if unsupported:
            warnings.warn(
                f"CCSDS unit '{unit_str}' has no equivalent; "
                f"returning dimensionless Quantity.",
                stacklevel=2,
            )

        # Collapse the two None/unsupported guards into one variable
        effective_unit = None if (unit_str is None or unsupported) else unit_str

        if mode is QuantityMode.PINT:
            u = _get_pint_ureg()
            if effective_unit is None:
                return u.Quantity(self.value)
            return u.Quantity(f"{self.value} {_ccsds_to_pint(effective_unit)}")

        # QuantityMode.ASTROPY
        from astropy import units as astropy_u

        if effective_unit is None:
            return astropy_u.Quantity(self.value)
        return astropy_u.Quantity(
            f"{self.value} {_ccsds_to_astropy(effective_unit)}"
        )

    return q


# ---------------------------------------------------------------------------
# Module discovery and patching
# ---------------------------------------------------------------------------


def _patch_module(module, wrapper_types: set[type], wrapper_q) -> None:
    """Patch all dataclasses in *module*.

    Wrapper types (members of *wrapper_types*) receive the ``.q()`` method;
    other dataclasses that have at least one wrapper-typed field receive a
    validating ``__setattr__``.
    """
    for _, cls in inspect.getmembers(module, inspect.isclass):
        if not dataclasses.is_dataclass(cls):
            continue
        # Guard against patching the same class twice (e.g. re-imported
        # modules).  The flag is read from the class's own namespace: a
        # plain getattr() would also see a flag inherited from an
        # already-patched base class and wrongly skip the subclass, leaving
        # any wrapper fields it adds unvalidated.
        if vars(cls).get("_ndm_quantity_patched", False):
            continue
        if cls in wrapper_types:
            # Wrapper types get a .q() method so users can extract a Quantity
            setattr(cls, "q", wrapper_q)
            setattr(cls, "_ndm_quantity_patched", True)
        else:
            field_map = _build_field_map(cls, wrapper_types)
            if not field_map:
                continue
            # Container types get a validating __setattr__ for their wrapper fields
            setattr(cls, "__setattr__", _make_setattr(field_map))
            setattr(cls, "_ndm_quantity_patched", True)


def _apply_quantity_support() -> None:
    """Patch all xsdata model classes. Called at import time."""
    wrapper_q = _make_wrapper_q()

    def patch(submodules, wrapper_types):
        for mod in submodules:
            _patch_module(mod, wrapper_types, wrapper_q)

    _patch_all_models(patch)


_apply_quantity_support()