Source code for exo_skryer.registry_cia

"""
registry_cia.py
===============
"""

from __future__ import annotations

from dataclasses import dataclass
from functools import lru_cache
from typing import List, Tuple, Optional
from pathlib import Path

import jax.numpy as jnp
import numpy as np
import zarr

# Public API of this module: the registry entry type, the load/reset entry
# points, and the cached accessor functions for the CIA data caches.
__all__ = [
    "CiaRegistryEntry",
    "reset_registry",
    "has_cia_data",
    "load_cia_registry",
    "cia_species_names",
    "cia_master_wavelength",
    "cia_sigma_cube",
    "cia_temperature_grid",
    "cia_temperature_grids",
    "cia_log10_temperature_grids",
    "cia_runtime_species_order",
    "cia_kept_pair_indices",
    "cia_pair_species_i",
    "cia_pair_species_j",
    "cia_retained_sigma_cube",
    "cia_retained_temperature_grids",
    "cia_retained_log10_temperature_grids",
]


# Dataclass containing the CIA table data
# Note: During preprocessing, all arrays are NumPy (CPU)
# They get converted to JAX (device) only at the final cache creation step
# Float64 throughout preprocessing for grids and cross sections; the cross
# sections are cast to float32 only when the device cache is created.
@dataclass(frozen=True)
class CiaRegistryEntry:
    """Immutable record holding one CIA table during preprocessing.

    All array fields are NumPy (CPU) float64 arrays while the registry is
    being assembled; conversion to JAX arrays happens later, when the
    module-level caches are created.
    """

    # Species label of the table (e.g. a pair name read from the data file).
    name: str
    # Index of the table in the configuration list it was loaded from.
    idx: int
    # Temperature grid of the table, NumPy float64 during preprocessing.
    temperatures: np.ndarray
    # Wavelength grid the cross sections are sampled on, NumPy float64.
    wavelengths: np.ndarray
    # Cross sections indexed as (temperature, wavelength), NumPy float64.
    cross_sections: np.ndarray
# Global scope cache data arrays.
# These module-level slots are filled by load_cia_registry() and emptied by
# reset_registry(); the lru_cached accessor functions read from them.
_CIA_SPECIES_NAMES: Tuple[str, ...] = ()  # Lightweight: only species names (few bytes)
_CIA_SIGMA_CACHE: jnp.ndarray | None = None
_CIA_TEMPERATURE_CACHE: jnp.ndarray | None = None
_CIA_WAVELENGTH_CACHE: jnp.ndarray | None = None
_CIA_LOG10_TEMPERATURE_CACHE: jnp.ndarray | None = None
_CIA_RUNTIME_SPECIES_ORDER: Tuple[str, ...] = ()
_CIA_KEPT_PAIR_INDICES_CACHE: jnp.ndarray | None = None
_CIA_PAIR_SPECIES_I_CACHE: jnp.ndarray | None = None
_CIA_PAIR_SPECIES_J_CACHE: jnp.ndarray | None = None
_CIA_RETAINED_SIGMA_CACHE: jnp.ndarray | None = None
_CIA_RETAINED_TEMPERATURE_CACHE: jnp.ndarray | None = None
_CIA_RETAINED_LOG10_TEMPERATURE_CACHE: jnp.ndarray | None = None


# Clear cache helper function
def _clear_cache() -> None:
    """Invalidate the lru_cache of every accessor function in this module.

    Must run after the module-level cache variables change; otherwise the
    accessors would keep returning the values memoised from earlier calls.
    """
    cached_accessors = (
        cia_species_names,
        cia_master_wavelength,
        cia_temperature_grids,
        cia_temperature_grid,
        cia_sigma_cube,
        cia_log10_temperature_grids,
        cia_runtime_species_order,
        cia_kept_pair_indices,
        cia_pair_species_i,
        cia_pair_species_j,
        cia_retained_sigma_cube,
        cia_retained_temperature_grids,
        cia_retained_log10_temperature_grids,
    )
    for accessor in cached_accessors:
        accessor.cache_clear()


# Reset all registry values
def reset_registry():
    """Return every module-level CIA cache to its empty state.

    The two tuple-valued slots become empty tuples, every array slot becomes
    ``None``, and the memoised accessor results are then discarded.
    """
    global _CIA_SPECIES_NAMES, _CIA_RUNTIME_SPECIES_ORDER
    global _CIA_SIGMA_CACHE, _CIA_TEMPERATURE_CACHE, _CIA_WAVELENGTH_CACHE
    global _CIA_LOG10_TEMPERATURE_CACHE, _CIA_KEPT_PAIR_INDICES_CACHE
    global _CIA_PAIR_SPECIES_I_CACHE, _CIA_PAIR_SPECIES_J_CACHE
    global _CIA_RETAINED_SIGMA_CACHE, _CIA_RETAINED_TEMPERATURE_CACHE
    global _CIA_RETAINED_LOG10_TEMPERATURE_CACHE

    # Name-like caches are tuples, so "empty" is the empty tuple.
    _CIA_SPECIES_NAMES = _CIA_RUNTIME_SPECIES_ORDER = ()

    # Array caches use None as the "not loaded" marker.
    _CIA_SIGMA_CACHE = _CIA_TEMPERATURE_CACHE = _CIA_WAVELENGTH_CACHE = None
    _CIA_LOG10_TEMPERATURE_CACHE = _CIA_KEPT_PAIR_INDICES_CACHE = None
    _CIA_PAIR_SPECIES_I_CACHE = _CIA_PAIR_SPECIES_J_CACHE = None
    _CIA_RETAINED_SIGMA_CACHE = _CIA_RETAINED_TEMPERATURE_CACHE = None
    _CIA_RETAINED_LOG10_TEMPERATURE_CACHE = None

    # Drop memoised accessor results so they do not serve stale data.
    _clear_cache()
# Helper function to check if data is in the global cache
def has_cia_data() -> bool:
    """Report whether the CIA cross-section cache is currently populated."""
    cache_is_empty = _CIA_SIGMA_CACHE is None
    return not cache_is_empty
# Shared post-processing for both on-disk CIA formats
def _build_cia_entry(
    name: str,
    index: int,
    path: str,
    temperatures: np.ndarray,
    wn: np.ndarray,
    xs: np.ndarray,
    target_wavelengths: np.ndarray,
) -> CiaRegistryEntry:
    """Resample one native CIA table onto the target wavelength grid.

    Parameters
    ----------
    name : species label stored in the resulting entry.
    index : position of the table in the configuration list.
    path : file the table came from (used only in messages).
    temperatures : temperature grid of the table, shape (nT,).
    wn : native wavenumber grid, shape (nwn,); it is reversed and converted
        with ``1e4 / wn`` to obtain the native wavelength grid.
    xs : cross sections on the native grid, shape (nT, nwn).  The Zarr loader
        documents these as log10 values, which matches the -199 floor used here.
    target_wavelengths : 1D wavelength grid to interpolate onto.

    Returns
    -------
    CiaRegistryEntry with float64 NumPy arrays on the target grid.

    Raises
    ------
    ValueError if ``target_wavelengths`` is not one-dimensional.
    """
    non_finite = ~np.isfinite(xs)
    if non_finite.any():
        # Only warn: the values are passed on to the interpolation unchanged.
        print(f"[warn] Non-finite CIA xs in {path}: count={np.count_nonzero(non_finite)}")

    # Convert to wavelength and reverse the array order so that the
    # wavelength axis keeps the ordering np.interp expects.
    native_wavelengths = 1.0e4 / wn[::-1]
    native_xs = xs[:, ::-1]

    target_wavelengths = np.asarray(target_wavelengths, dtype=float)
    if target_wavelengths.ndim != 1:
        raise ValueError(f"lam_target must be 1D, got shape {target_wavelengths.shape} for {path}")

    lam_min, lam_max = float(target_wavelengths[0]), float(target_wavelengths[-1])
    wl_min, wl_max = float(native_wavelengths.min()), float(native_wavelengths.max())
    if lam_min < wl_min or lam_max > wl_max:
        print(
            "[warn] Target wavelength grid "
            f"[{lam_min:.6g}, {lam_max:.6g}] extends beyond native CIA grid "
            f"[{wl_min:.6g}, {wl_max:.6g}] in {path}; "
            "filling out-of-range σ with 1e-199."
        )

    # Interpolate to the master wavelength grid, one temperature row at a time.
    # Use float64 for the cross sections to keep dtype consistent.
    n_temperatures, _ = native_xs.shape
    xs_interp = np.empty((n_temperatures, target_wavelengths.size), dtype=np.float64)
    for idx_temp in range(n_temperatures):
        xs_interp[idx_temp, :] = np.interp(
            target_wavelengths,
            native_wavelengths,
            native_xs[idx_temp, :],
            left=-199.0,
            right=-199.0,
        )
    # Floor the table at the same value used for out-of-range filling.
    xs_interp = np.maximum(xs_interp, -199.0)

    # Return a CIA table registry entry with NumPy arrays (converted to JAX later).
    return CiaRegistryEntry(
        name=name,
        idx=index,
        temperatures=temperatures.astype(np.float64),
        wavelengths=target_wavelengths.astype(np.float64),
        cross_sections=xs_interp,
    )


# Load the CIA cross section data from the formatted npz files
def _load_cia_npz(index: int, path: str, target_wavelengths: np.ndarray) -> CiaRegistryEntry:
    """Load one CIA table from an ``.npz`` archive and resample it.

    The archive is read for the keys ``"mol"`` (species name), ``"T"``,
    ``"wn"`` and ``"sig"``.  The archive is closed as soon as the arrays have
    been copied into memory, so no file handle is left open.

    Parameters
    ----------
    index : position of the table in the configuration list.
    path : path of the ``.npz`` file.
    target_wavelengths : 1D wavelength grid to interpolate onto.

    Returns
    -------
    CiaRegistryEntry on the target wavelength grid.
    """
    # NOTE(security): allow_pickle=True lets NumPy unpickle object arrays, which
    # can execute arbitrary code.  It is kept because existing tables may store
    # "mol" as an object array; only load CIA files from trusted sources.
    with np.load(path, allow_pickle=True) as data:
        name = data["mol"]
        # Get the temperature array, wavenumbers and cross-sections
        temperatures = np.asarray(data["T"], dtype=float)
        wn = np.asarray(data["wn"], dtype=float)
        xs = np.asarray(data["sig"], dtype=float)

    # The name may come back as a NumPy scalar/array; normalise it to str.
    if isinstance(name, np.ndarray):
        name = name.tolist()
    if not isinstance(name, str):
        name = str(name)

    return _build_cia_entry(name, index, path, temperatures, wn, xs, target_wavelengths)


def _load_cia_zarr(index: int, path: str, target_wavelengths: np.ndarray) -> CiaRegistryEntry:
    """
    Load CIA opacity tables stored in the custom Zarr format.

    Expected Zarr contents:
    - attrs["mol"]: species name (string)
    - T: temperature grid in K (nT,)
    - wn: wavenumber grid in cm^-1 (nwn,)
    - sig: log10 cross-sections (nT, nwn)

    A path ending in ``.zip`` is opened through ``ZipStore``; that store is
    closed once the arrays have been read into memory.  If the ``mol``
    attribute is missing the name falls back to ``cia_<index>``.
    """
    store = None
    try:
        if path.endswith(".zip"):
            from zarr.storage import ZipStore

            store = ZipStore(path, mode="r")
            root = zarr.open_group(store=store, mode="r")
        else:
            root = zarr.open_group(path, mode="r")

        name = str(root.attrs.get("mol", f"cia_{index}"))
        # Slicing with [:] materialises each array in memory before the store closes.
        temperatures = np.asarray(root["T"][:], dtype=float)
        wn = np.asarray(root["wn"][:], dtype=float)
        xs = np.asarray(root["sig"][:], dtype=float)
    finally:
        if store is not None:
            store.close()

    return _build_cia_entry(name, index, path, temperatures, wn, xs, target_wavelengths)


# Pad the tables to a rectangle (in dimension) - usually only in T as wavelength grids are the same lengths
# Uses NumPy for preprocessing (CPU-based padding before sending to device)
def _rectangularize_entries(entries: List[CiaRegistryEntry]) -> Tuple[CiaRegistryEntry, ...]:
    """Pad all entries to a common number of temperature points.

    Every entry must share the wavelength grid of the first entry.  Entries
    with fewer temperature points than the longest one are extended by
    repeating their last temperature and last cross-section row
    (``np.pad`` with ``mode="edge"``).

    Parameters
    ----------
    entries : loaded CIA entries, possibly with different temperature counts.

    Returns
    -------
    Tuple of new entries with identical array shapes; ``()`` for empty input.

    Raises
    ------
    ValueError if wavelength grids differ between entries, or if an entry's
    cross-section table does not match the wavelength grid length.
    """
    if not entries:
        return ()

    base_wavelengths = entries[0].wavelengths
    expected_wavelengths = base_wavelengths.shape[0]
    for entry in entries[1:]:
        if entry.wavelengths.shape != base_wavelengths.shape or not np.allclose(entry.wavelengths, base_wavelengths):
            raise ValueError(f"CIA wavelength grids differ between {entries[0].name} and {entry.name}.")

    max_temperatures = max(entry.temperatures.shape[0] for entry in entries)
    padded_entries: List[CiaRegistryEntry] = []
    for entry in entries:
        # Keep as NumPy arrays for preprocessing
        temperatures = entry.temperatures
        xs = entry.cross_sections
        n_temperatures, wavelength_count = xs.shape
        if wavelength_count != expected_wavelengths:
            raise ValueError(f"Species {entry.name} has λ grid length {wavelength_count}, expected {expected_wavelengths}.")
        pad_temperatures = max_temperatures - n_temperatures
        if pad_temperatures > 0:
            # Use NumPy padding (CPU-based); only the temperature axis grows.
            temperatures = np.pad(temperatures, (0, pad_temperatures), mode="edge")
            xs = np.pad(xs, ((0, pad_temperatures), (0, 0)), mode="edge")
        padded_entries.append(
            CiaRegistryEntry(
                name=entry.name,
                idx=entry.idx,
                temperatures=temperatures,
                wavelengths=base_wavelengths,
                cross_sections=xs,
            )
        )
    return tuple(padded_entries)


# Load in the CIA table data - add the data to global scope cache files
def load_cia_registry(cfg, obs, lam_master: Optional[np.ndarray] = None, base_dir: Optional[Path] = None) -> None:
    """Load every configured CIA table and publish it in the module caches.

    Parameters
    ----------
    cfg : configuration object; ``cfg.opac.cia`` is iterated and each item is
        read for ``.path`` and, when present, ``.species``.
    obs : mapping whose ``"wl"`` entry supplies the wavelength grid when
        ``lam_master`` is not given.
    lam_master : optional master wavelength grid that takes precedence over
        ``obs["wl"]``.
    base_dir : optional directory that relative table paths are resolved
        against; without it they are resolved against the working directory.

    Returns
    -------
    None.  Results are stored in the module-level caches.  When no CIA tables
    are configured, or none are loaded, the registry is reset instead.

    Raises
    ------
    ValueError for an unsupported file extension, or for a loaded species name
    that is not of the form ``'A-B'``.
    """
    # Initialise the global caches
    global _CIA_SPECIES_NAMES, _CIA_SIGMA_CACHE, _CIA_TEMPERATURE_CACHE, _CIA_WAVELENGTH_CACHE, _CIA_LOG10_TEMPERATURE_CACHE
    global _CIA_RUNTIME_SPECIES_ORDER, _CIA_KEPT_PAIR_INDICES_CACHE
    global _CIA_PAIR_SPECIES_I_CACHE, _CIA_PAIR_SPECIES_J_CACHE
    global _CIA_RETAINED_SIGMA_CACHE, _CIA_RETAINED_TEMPERATURE_CACHE
    global _CIA_RETAINED_LOG10_TEMPERATURE_CACHE

    if not getattr(cfg.opac, "cia", None):
        reset_registry()
        return

    # Use observational wavelengths if no master wavelength grid is available
    if lam_master is None:
        wavelengths = np.asarray(obs["wl"], dtype=float)
    else:
        wavelengths = np.asarray(lam_master, dtype=float)

    # Read in each CIA table data
    loaded: List[CiaRegistryEntry] = []
    for index, spec in enumerate(cfg.opac.cia):
        name = getattr(spec, "species", spec)
        if name == "H-":
            print("[warn] cfg.opac.cia includes 'H-': this is no longer treated as a CIA table.")
            print("[warn] Enable H- continuum under cfg.opac.special instead (bf/ff handled as special opacity).")
            continue

        # Make the table path absolute, anchoring relative paths at base_dir if given.
        cia_path = Path(spec.path).expanduser()
        if not cia_path.is_absolute():
            anchored = cia_path if base_dir is None else Path(base_dir) / cia_path
            cia_path = anchored.resolve()
        path_str = str(cia_path)
        print("[CIA] Reading cia xs for", name, "@", path_str)

        # Pick the loader from the file extension.
        if path_str.endswith((".zarr", ".zarr.zip")):
            if path_str.endswith(".zarr") and not cia_path.exists():
                # Fall back to a zipped store sitting next to the missing directory.
                zip_fallback = Path(path_str + ".zip")
                if zip_fallback.exists():
                    path_str = str(zip_fallback)
                    print(f"[CIA] .zarr directory not found; using {zip_fallback.name}")
            loader = _load_cia_zarr
        elif path_str.endswith(".npz"):
            loader = _load_cia_npz
        else:
            raise ValueError(f"Unsupported file format for {path_str}. Expected .npz, .zarr or .zarr.zip")
        loaded.append(loader(index, path_str, wavelengths))

    # For JAX, need to pad to make the tables rectangular with the same number of T grids
    tables = _rectangularize_entries(loaded)
    if not tables:
        reset_registry()
        return

    # ============================================================================
    # CRITICAL: Convert NumPy arrays to JAX arrays here (ONE transfer to device)
    # ============================================================================
    # All preprocessing is done in NumPy (CPU). Now we send the final data
    # to the device (GPU/CPU as configured) for use in JIT-compiled forward model.
    # Mixed precision strategy:
    #   Float64 for grids, float32 for cross sections.
    # ============================================================================
    print(f"[CIA] Transferring {len(tables)} species to device...")

    # Stack cross sections: (n_species, nT, nwl); stored on device as float32
    sigma_stacked = np.stack([table.cross_sections for table in tables], axis=0)
    _CIA_SIGMA_CACHE = jnp.asarray(sigma_stacked, dtype=jnp.float32)

    # Stack temperature grids: (n_species, nT) - keep as float64 for accuracy
    temp_stacked = np.stack([table.temperatures for table in tables], axis=0)
    _CIA_TEMPERATURE_CACHE = jnp.asarray(temp_stacked, dtype=jnp.float64)

    # All tables share one wavelength grid after rectangularization.
    _CIA_WAVELENGTH_CACHE = jnp.asarray(tables[0].wavelengths, dtype=jnp.float64)

    # Pre-compute log10 of temperature grids for efficient interpolation
    _CIA_LOG10_TEMPERATURE_CACHE = jnp.log10(_CIA_TEMPERATURE_CACHE)

    print(f"[CIA] Cross section cache: {_CIA_SIGMA_CACHE.shape} (dtype: {_CIA_SIGMA_CACHE.dtype})")
    print(f"[CIA] Temperature cache: {_CIA_TEMPERATURE_CACHE.shape} (dtype: {_CIA_TEMPERATURE_CACHE.dtype})")
    print("[CIA] Cached log10(T) grids for efficient interpolation")

    # Estimate memory usage
    sigma_mb = _CIA_SIGMA_CACHE.size * _CIA_SIGMA_CACHE.itemsize / 1024**2
    temp_mb = _CIA_TEMPERATURE_CACHE.size * _CIA_TEMPERATURE_CACHE.itemsize / 1024**2
    total_mb = sigma_mb + temp_mb
    print(f"[CIA] Estimated device memory: {total_mb:.1f} MB (σ: {sigma_mb:.1f} MB, T: {temp_mb:.1f} MB)")

    # Extract species names (lightweight: just strings)
    _CIA_SPECIES_NAMES = tuple(table.name for table in tables)

    # Tables whose name is "H-" are excluded from the retained pair set.
    kept_pair_indices = [
        position for position, label in enumerate(_CIA_SPECIES_NAMES) if label.strip() != "H-"
    ]
    _CIA_KEPT_PAIR_INDICES_CACHE = jnp.asarray(kept_pair_indices, dtype=jnp.int32)

    # Give each distinct collision partner an integer id in order of first
    # appearance, and record the (i, j) id pair of every kept table.
    species_to_idx: dict[str, int] = {}
    pair_i: List[int] = []
    pair_j: List[int] = []
    for idx in kept_pair_indices:
        parts = _CIA_SPECIES_NAMES[idx].strip().split("-")
        if len(parts) != 2:
            raise ValueError(f"CIA species '{_CIA_SPECIES_NAMES[idx]}' must be in 'A-B' format")
        first, second = parts
        # setdefault keeps an existing id; a new species gets the next free one.
        pair_i.append(species_to_idx.setdefault(first, len(species_to_idx)))
        pair_j.append(species_to_idx.setdefault(second, len(species_to_idx)))

    # dict preserves insertion order, i.e. the order of first appearance.
    _CIA_RUNTIME_SPECIES_ORDER = tuple(species_to_idx)
    _CIA_PAIR_SPECIES_I_CACHE = jnp.asarray(pair_i, dtype=jnp.int32)
    _CIA_PAIR_SPECIES_J_CACHE = jnp.asarray(pair_j, dtype=jnp.int32)
    _CIA_RETAINED_SIGMA_CACHE = _CIA_SIGMA_CACHE[_CIA_KEPT_PAIR_INDICES_CACHE]
    _CIA_RETAINED_TEMPERATURE_CACHE = _CIA_TEMPERATURE_CACHE[_CIA_KEPT_PAIR_INDICES_CACHE]
    _CIA_RETAINED_LOG10_TEMPERATURE_CACHE = _CIA_LOG10_TEMPERATURE_CACHE[_CIA_KEPT_PAIR_INDICES_CACHE]

    # Delete NumPy arrays to free memory (JAX caches now hold the data on device)
    # This saves ~50 MB for typical CIA tables
    del tables, loaded, sigma_stacked, temp_stacked
    print("[CIA] Freed NumPy temporary arrays from CPU memory")

    # Accessors may have memoised older data; make them re-read the caches.
    _clear_cache()
### -- lru cached helper functions below --- ###
@lru_cache(maxsize=None)
def cia_species_names() -> Tuple[str, ...]:
    """Return the names of the loaded CIA tables.

    Raises
    ------
    RuntimeError if no names are registered.
    """
    names = _CIA_SPECIES_NAMES
    if names:
        return names
    raise RuntimeError("CIA registry empty; call build_opacities() first.")
@lru_cache(maxsize=None)
def cia_master_wavelength() -> jnp.ndarray:
    """Return the wavelength grid shared by all loaded CIA tables.

    Raises
    ------
    RuntimeError if the wavelength cache has not been filled.
    """
    grid = _CIA_WAVELENGTH_CACHE
    if grid is not None:
        return grid
    raise RuntimeError("CIA registry empty; call build_opacities() first.")
@lru_cache(maxsize=None)
def cia_sigma_cube() -> jnp.ndarray:
    """Return the stacked CIA cross-section cache for all loaded tables.

    Raises
    ------
    RuntimeError if the cross-section cache has not been filled.
    """
    cube = _CIA_SIGMA_CACHE
    if cube is not None:
        return cube
    raise RuntimeError("CIA σ cube not built; call build_opacities() first.")
@lru_cache(maxsize=None)
def cia_temperature_grids() -> jnp.ndarray:
    """Return the stacked temperature grids, one row per loaded CIA table.

    Raises
    ------
    RuntimeError if the temperature cache has not been filled.
    """
    grids = _CIA_TEMPERATURE_CACHE
    if grids is not None:
        return grids
    raise RuntimeError("CIA temperature grids not built; call build_opacities() first.")
@lru_cache(maxsize=None)
def cia_temperature_grid() -> jnp.ndarray:
    """Return the temperature grid of the first loaded CIA table.

    Any error raised by ``cia_temperature_grids()`` propagates unchanged.
    """
    all_grids = cia_temperature_grids()
    return all_grids[0]
@lru_cache(maxsize=None)
def cia_log10_temperature_grids() -> jnp.ndarray:
    """Return the precomputed log10 of the stacked temperature grids.

    Raises
    ------
    RuntimeError if the log10(T) cache has not been filled.
    """
    log_grids = _CIA_LOG10_TEMPERATURE_CACHE
    if log_grids is not None:
        return log_grids
    raise RuntimeError("CIA log10(T) grids not built; call build_opacities() first.")
@lru_cache(maxsize=None)
def cia_runtime_species_order() -> Tuple[str, ...]:
    """Return the distinct collision-partner species in first-seen order.

    Unlike the other accessors this never raises: an unset registry yields
    an empty tuple.
    """
    return _CIA_RUNTIME_SPECIES_ORDER or ()
@lru_cache(maxsize=None)
def cia_kept_pair_indices() -> jnp.ndarray:
    """Return the indices of the loaded tables that are retained as pairs.

    Raises
    ------
    RuntimeError if the kept-pair index cache has not been filled.
    """
    kept = _CIA_KEPT_PAIR_INDICES_CACHE
    if kept is not None:
        return kept
    raise RuntimeError("CIA kept-pair indices not built; call build_opacities() first.")
@lru_cache(maxsize=None)
def cia_pair_species_i() -> jnp.ndarray:
    """Return, per retained pair, the runtime-species index of its first member.

    Raises
    ------
    RuntimeError if the species-i index cache has not been filled.
    """
    first_members = _CIA_PAIR_SPECIES_I_CACHE
    if first_members is not None:
        return first_members
    raise RuntimeError("CIA pair species-i indices not built; call build_opacities() first.")
@lru_cache(maxsize=None)
def cia_pair_species_j() -> jnp.ndarray:
    """Return, per retained pair, the runtime-species index of its second member.

    Raises
    ------
    RuntimeError if the species-j index cache has not been filled.
    """
    second_members = _CIA_PAIR_SPECIES_J_CACHE
    if second_members is not None:
        return second_members
    raise RuntimeError("CIA pair species-j indices not built; call build_opacities() first.")
@lru_cache(maxsize=None)
def cia_retained_sigma_cube() -> jnp.ndarray:
    """Return the cross-section cache restricted to the retained pairs.

    Raises
    ------
    RuntimeError if the retained cross-section cache has not been filled.
    """
    retained = _CIA_RETAINED_SIGMA_CACHE
    if retained is not None:
        return retained
    raise RuntimeError("CIA retained sigma cube not built; call build_opacities() first.")
@lru_cache(maxsize=None)
def cia_retained_temperature_grids() -> jnp.ndarray:
    """Return the temperature grids restricted to the retained pairs.

    Raises
    ------
    RuntimeError if the retained temperature cache has not been filled.
    """
    retained = _CIA_RETAINED_TEMPERATURE_CACHE
    if retained is not None:
        return retained
    raise RuntimeError("CIA retained temperature grids not built; call build_opacities() first.")
@lru_cache(maxsize=None)
def cia_retained_log10_temperature_grids() -> jnp.ndarray:
    """Return the log10 temperature grids restricted to the retained pairs.

    Raises
    ------
    RuntimeError if the retained log10(T) cache has not been filled.
    """
    retained = _CIA_RETAINED_LOG10_TEMPERATURE_CACHE
    if retained is not None:
        return retained
    raise RuntimeError("CIA retained log10(T) grids not built; call build_opacities() first.")