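"""
registry_ck.py
==============

Registry of correlated-k opacity tables: loads per-species tables (NPZ,
petitRADTRANS HDF5, or Zarr), pads them to a common rectangular shape, and
caches them as JAX arrays for the forward model.
"""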
from __future__ import annotations
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Tuple, Optional
from pathlib import Path
import jax.numpy as jnp
import numpy as np
import h5py
import zarr
# Public API: the registry entry type, the load/reset entry points, and the
# lru-cached accessors for the device-side correlated-k caches.
__all__ = [
"CKRegistryEntry",
"reset_registry",
"has_ck_data",
"load_ck_registry",
"ck_species_names",
"ck_master_wavelength",
"ck_pressure_grid",
"ck_temperature_grid",
"ck_temperature_grids",
"ck_sigma_cube",
"ck_g_points",
"ck_g_weights",
"ck_log10_pressure_grid",
"ck_log10_temperature_grids",
"ck_runtime_species_order",
"ck_g_points_1d",
"ck_g_weights_1d",
]
# Dataclass for each of the correlated-k opacity tables
# Note: During preprocessing, all arrays are NumPy (CPU)
# They get converted to JAX (device) only at the final cache creation step
# Mixed precision: float64 for grids (better accuracy), float32 for cross sections (memory savings)
[docs]
@dataclass(frozen=True)
class CKRegistryEntry:
    """One species' correlated-k opacity table, held as host-side NumPy arrays.

    Entries are produced by the ``_load_ck_*`` readers, padded to a common
    shape by ``_rectangularize_entries`` and finally stacked into JAX device
    caches by ``load_ck_registry``.

    Attributes
    ----------
    name : species label.
    idx : position of the species in the configuration list.
    pressures : pressure grid in bar, float64, shape (n_pressure,).
    temperatures : temperature grid in K, float64, shape (n_temperature,).
    wavelengths : wavelength grid in microns, float64, shape (n_wavelength,).
    g_points : sorted g-point locations in [0, 1], float64, shape (n_g,).
    g_weights : quadrature weights (normalized to sum to 1 by the loaders),
        float64, shape (n_g,).
    cross_sections : log10 cross sections, float32 to save memory,
        shape (n_temperature, n_pressure, n_wavelength, n_g).
    """

    # Identification
    name: str
    idx: int
    # Grids (float64 for interpolation accuracy)
    pressures: np.ndarray
    temperatures: np.ndarray
    wavelengths: np.ndarray
    g_points: np.ndarray
    g_weights: np.ndarray
    # Opacity cube (float32 for memory)
    cross_sections: np.ndarray
# Global registries and caches for the forward model.
# reset_registry() sets all of these back to their empty values and
# load_ck_registry() repopulates them; the lru-cached ck_* accessors read them.
_CK_SPECIES_NAMES: Tuple[str, ...] = () # Lightweight: only species names (few bytes)
_CK_SIGMA_CACHE: jnp.ndarray | None = None  # log10 cross sections, (n_species, nT, nP, nwl, ng), requested as float32
_CK_TEMPERATURE_CACHE: jnp.ndarray | None = None  # per-species temperature grids in K, (n_species, nT); shorter grids edge-padded
_CK_G_POINTS_CACHE: jnp.ndarray | None = None  # per-species g-points, (n_species, ng)
_CK_G_WEIGHTS_CACHE: jnp.ndarray | None = None  # per-species quadrature weights (normalized by the loaders), (n_species, ng)
_CK_WAVELENGTH_CACHE: jnp.ndarray | None = None  # shared wavelength grid in microns, (nwl,)
_CK_PRESSURE_CACHE: jnp.ndarray | None = None  # shared pressure grid in bar, (nP,)
_CK_LOG10_TEMPERATURE_CACHE: jnp.ndarray | None = None  # log10 of _CK_TEMPERATURE_CACHE
_CK_LOG10_PRESSURE_CACHE: jnp.ndarray | None = None  # log10 of _CK_PRESSURE_CACHE
_CK_RUNTIME_SPECIES_ORDER: Tuple[str, ...] = ()  # currently set to the same tuple as _CK_SPECIES_NAMES
_CK_G_POINTS_1D_CACHE: jnp.ndarray | None = None  # row 0 of _CK_G_POINTS_CACHE (g grids are checked identical across species)
_CK_G_WEIGHTS_1D_CACHE: jnp.ndarray | None = None  # row 0 of _CK_G_WEIGHTS_CACHE
# Clear all the cache entries
def _clear_cache():
    """Invalidate every lru-cached ``ck_*`` accessor.

    Must be called whenever the module-level caches are reassigned, otherwise
    the accessors would keep returning the previously memoized arrays.
    """
    memoized_accessors = (
        ck_species_names,
        ck_runtime_species_order,
        ck_master_wavelength,
        ck_pressure_grid,
        ck_temperature_grid,
        ck_temperature_grids,
        ck_sigma_cube,
        ck_g_points,
        ck_g_weights,
        ck_g_points_1d,
        ck_g_weights_1d,
        ck_log10_pressure_grid,
        ck_log10_temperature_grids,
    )
    for accessor in memoized_accessors:
        accessor.cache_clear()
# Reset all the global registries
[docs]
def reset_registry() -> None:
    """Drop all loaded correlated-k data, leaving the registry empty.

    Afterwards ``has_ck_data()`` returns ``False`` and every ``ck_*`` accessor
    raises ``RuntimeError`` until ``load_ck_registry`` runs again.
    """
    global _CK_SPECIES_NAMES, _CK_RUNTIME_SPECIES_ORDER
    global _CK_SIGMA_CACHE, _CK_TEMPERATURE_CACHE, _CK_LOG10_TEMPERATURE_CACHE
    global _CK_WAVELENGTH_CACHE, _CK_PRESSURE_CACHE, _CK_LOG10_PRESSURE_CACHE
    global _CK_G_POINTS_CACHE, _CK_G_WEIGHTS_CACHE, _CK_G_POINTS_1D_CACHE, _CK_G_WEIGHTS_1D_CACHE

    # Species tuples: the empty tuple means "not loaded".
    _CK_SPECIES_NAMES = _CK_RUNTIME_SPECIES_ORDER = ()

    # Device arrays: None means "not loaded".
    _CK_SIGMA_CACHE = None
    _CK_TEMPERATURE_CACHE = _CK_LOG10_TEMPERATURE_CACHE = None
    _CK_WAVELENGTH_CACHE = None
    _CK_PRESSURE_CACHE = _CK_LOG10_PRESSURE_CACHE = None
    _CK_G_POINTS_CACHE = _CK_G_WEIGHTS_CACHE = None
    _CK_G_POINTS_1D_CACHE = _CK_G_WEIGHTS_1D_CACHE = None

    # Forget anything the memoized accessors still hold.
    _clear_cache()
# Check if the registries are set or not
[docs]
def has_ck_data() -> bool:
    """Report whether a correlated-k cross-section cube is currently loaded.

    Only the sigma cache is consulted. ``load_ck_registry`` sets it and
    ``reset_registry`` clears it.
    """
    sigma_cube = _CK_SIGMA_CACHE
    return not (sigma_cube is None)
# Function to load petitRADTRANS HDF5 correlated-k opacity data
def _load_ck_h5(index: int, spec, path: str, obs: dict, use_full_grid: bool = False) -> CKRegistryEntry:
    """
    Load a petitRADTRANS HDF5 correlated-k opacity table.

    Datasets read from the file:
      - p: pressure grid in bar (nP,)
      - t: temperature grid in K (nT,)
      - bin_centers: wavenumber bin centres in cm^-1 (nwl,)
      - kcoeff: correlated-k coefficients in cm^2/molecule (nP, nT, nwl, ng)
      - weights: Gauss quadrature weights (ng,)
      - samples (optional): g-point locations (ng,)
      - ngauss (optional): number of g-points. It is only consulted when
        ``samples`` is absent, to build a uniform g grid on [0, 1]. If it is
        also absent, the length of ``weights`` is used.

    Parameters
    ----------
    index : position of the species in the configuration list.
    spec : species config object; ``spec.species`` is used as the name if present,
        otherwise ``f"ck_{index}"``.
    path : path to the HDF5 file.
    obs : observation dict with ``"wl"`` (band centres, microns) and ``"dwl"``
        (half-widths, microns). Only read when ``use_full_grid`` is False.
    use_full_grid : keep every wavelength bin instead of cutting to the observed bands.

    Returns
    -------
    CKRegistryEntry with:
      - pressures in bar (nP,) and temperatures in K (nT,)
      - wavelengths in microns, ascending, cut to the observation bands
      - g_points sorted ascending, with weights reordered to match and normalized
        to sum to 1
      - cross_sections as float32 log10(cm^2), shape (nT, nP, nwl_cut, ng)

    Raises
    ------
    ValueError
        If the g arrays, the kcoeff shape or the wavenumber grid are invalid,
        or if no bins fall inside the observation bands.

    Note: Correlated-k tables are pre-banded and cannot be interpolated in
    wavelength, so the table is cut (not regridded) to the observed bands.
    """
    name = getattr(spec, "species", f"ck_{index}")
    with h5py.File(path, 'r') as f:
        # Read everything eagerly; all validation happens after the file is closed.
        pressures = np.asarray(f['p'][:], dtype=float)  # bar, (nP,)
        temperatures = np.asarray(f['t'][:], dtype=float)  # K, (nT,)
        bin_centers_wn = np.asarray(f['bin_centers'][:], dtype=float)  # cm^-1, (nwl,)
        native_kcoeff = np.asarray(f['kcoeff'][:], dtype=float)  # cm^2/molecule, (nP, nT, nwl, ng)
        weights = np.asarray(f['weights'][:], dtype=float)  # (ng,)
        if 'samples' in f:
            g_points = np.asarray(f['samples'][:], dtype=float)
        else:
            # No stored g-points: fall back to a uniform grid on [0, 1].
            if 'ngauss' in f:
                # `[()]` reads scalar and array datasets alike. The single element
                # is extracted explicitly because int() on an array with
                # ndim > 0 is deprecated in NumPy.
                ngauss_values = np.asarray(f['ngauss'][()]).ravel()
                if ngauss_values.size != 1:
                    raise ValueError(
                        f"Invalid petitRADTRANS k-table {path}: 'ngauss' must hold a single value."
                    )
                ngauss = int(ngauss_values[0])
            else:
                ngauss = weights.size
            g_points = np.linspace(0.0, 1.0, ngauss)

    # Check the rank before comparing lengths.
    if g_points.ndim != 1 or weights.ndim != 1:
        raise ValueError(f"Invalid petitRADTRANS k-table {path}: expected 1D g arrays.")
    if g_points.shape[0] != weights.shape[0]:
        raise ValueError(
            f"Invalid petitRADTRANS k-table {path}: g_points has length {g_points.shape[0]} "
            f"but weights has length {weights.shape[0]}."
        )
    # The transpose and indexing below assume kcoeff is (nP, nT, nwl, ng) and
    # consistent with the 1-D grids. A mismatch would otherwise be mis-sliced silently.
    expected_shape = (pressures.size, temperatures.size, bin_centers_wn.size, g_points.shape[0])
    if native_kcoeff.shape != expected_shape:
        raise ValueError(
            f"Invalid petitRADTRANS k-table {path}: kcoeff has shape {native_kcoeff.shape}, "
            f"expected (nP, nT, nwl, ng) = {expected_shape}."
        )
    # 10000 / nu below needs finite, strictly positive wavenumbers.
    if not np.all(np.isfinite(bin_centers_wn)) or np.any(bin_centers_wn <= 0.0):
        raise ValueError(
            f"Invalid petitRADTRANS k-table {path}: bin_centers must be finite and positive."
        )

    # Sort the g-points, keeping the weights and the kcoeff g-axis aligned with them.
    # Some table formats store g-points unsorted.
    g_sort = np.argsort(g_points)
    g_points = g_points[g_sort]
    weights = weights[g_sort]
    native_kcoeff = native_kcoeff[..., g_sort]
    if np.any(~np.isfinite(g_points)) or np.any(~np.isfinite(weights)):
        raise ValueError(f"Invalid petitRADTRANS k-table {path}: non-finite g-points or weights.")
    if np.any(np.diff(g_points) <= 0.0):
        raise ValueError(f"Invalid petitRADTRANS k-table {path}: g-points are not strictly increasing.")
    if g_points.min() < 0.0 or g_points.max() > 1.0:
        raise ValueError(f"Invalid petitRADTRANS k-table {path}: g-points must lie within [0, 1].")

    # Normalize the quadrature weights: RT assumes sum(weights) == 1.
    weights = np.clip(weights, 0.0, None)
    wsum = float(np.sum(weights))
    if not np.isfinite(wsum) or wsum <= 0.0:
        raise ValueError(f"Invalid petitRADTRANS k-table {path}: non-positive weight sum {wsum}.")
    weights = weights / wsum

    # Convert wavenumber bin centres to wavelength: lambda[um] = 10000 / nu[cm^-1].
    wavelengths = 10000.0 / bin_centers_wn
    # Wavenumbers are typically descending, so sort to get ascending wavelengths.
    sort_idx = np.argsort(wavelengths)
    wavelengths = wavelengths[sort_idx]
    # (nP, nT, nwl, ng) -> (nT, nP, nwl, ng), applying the same wavelength order.
    kcoeff_transposed = np.transpose(native_kcoeff, (1, 0, 2, 3))[:, :, sort_idx, :]

    # Select the bins to keep.
    if use_full_grid:
        mask = np.ones_like(wavelengths, dtype=bool)
        print(f"[c-k] Using full wavelength grid for {name}: {len(wavelengths)} bins")
    else:
        wl_obs = np.asarray(obs["wl"], dtype=float)
        dwl_obs = np.asarray(obs["dwl"], dtype=float)
        left_edges = wl_obs - dwl_obs
        right_edges = wl_obs + dwl_obs
        # Keep a bin if it falls inside any observation band (edges inclusive).
        mask = np.any(
            (wavelengths[None, :] >= left_edges[:, None]) & (wavelengths[None, :] <= right_edges[:, None]),
            axis=0,
        )
        if not np.any(mask):
            raise ValueError(f"No CK wavelengths for {name} lie within observation bins.")
        print(f"[c-k] Cut wavelength grid for {name}: {np.sum(mask)}/{len(wavelengths)} bins retained")

    wavelengths_cut = wavelengths[mask]
    kcoeff_cut = kcoeff_transposed[:, :, mask, :]

    # Store log10 cross sections. Zeros are floored at 1e-99 (log10 = -99),
    # and the result is float32 to save memory.
    min_xs = 1e-99
    kcoeff_log = np.log10(np.maximum(kcoeff_cut, min_xs)).astype(np.float32)

    # NumPy arrays here; conversion to JAX happens in load_ck_registry.
    # Mixed precision: float64 grids, float32 cross sections.
    return CKRegistryEntry(
        name=name,
        idx=index,
        pressures=pressures.astype(np.float64),
        temperatures=temperatures.astype(np.float64),
        wavelengths=wavelengths_cut.astype(np.float64),
        g_points=g_points.astype(np.float64),
        g_weights=weights.astype(np.float64),
        cross_sections=kcoeff_log,
    )
def _load_ck_zarr(index: int, spec, path: str, obs: dict, use_full_grid: bool = False) -> CKRegistryEntry:
    """
    Load a correlated-k opacity table stored in the custom Zarr format.

    Expected Zarr contents:
      - attrs["molecule"]: species name (optional; falls back to ``spec.species``
        and then ``f"ck_{index}"``)
      - pressure: pressure grid in bar (nP,)
      - temperature: temperature grid in K (nT,)
      - wavelength: wavelength grid in microns (nwl,)
      - g_points: g-point locations (ng,) (optional; uniform on [0, 1] if absent)
      - g_weights: quadrature weights (ng,) (optional; equal weights if absent)
      - cross_section: log10 cross sections (nT, nP, nwl, ng). Values are used
        as-is and only cast to float32.

    A path ending in ``.zip`` is opened through ``zarr.storage.ZipStore``. All
    arrays are read eagerly into NumPy and the store is closed before returning,
    including on error.

    Parameters
    ----------
    index : position of the species in the configuration list.
    spec : species config object, used for the fallback name.
    path : Zarr directory or ``.zip`` archive path.
    obs : observation dict with ``"wl"`` (band centres) and ``"dwl"`` (half-widths).
        Only read when ``use_full_grid`` is False.
    use_full_grid : keep every wavelength bin instead of cutting to the observed bands.

    Returns
    -------
    CKRegistryEntry with g-points sorted ascending, weights normalized to sum
    to 1, float64 grids and float32 cross sections.

    Raises
    ------
    ValueError
        If the cube is not 4-D, any grid disagrees with the cube's axes, the g
        arrays are invalid, or no bins fall inside the observation bands.
    """
    zip_store = None
    try:
        if path.endswith(".zip"):
            from zarr.storage import ZipStore
            zip_store = ZipStore(path, mode="r")
            root = zarr.open_group(store=zip_store, mode="r")
        else:
            root = zarr.open_group(path, mode="r")
        name = str(root.attrs.get("molecule", getattr(spec, "species", f"ck_{index}")))
        pressures = np.asarray(root["pressure"][:], dtype=float)
        temperatures = np.asarray(root["temperature"][:], dtype=float)
        wavelengths = np.asarray(root["wavelength"][:], dtype=float)
        cross_section = np.asarray(root["cross_section"][:], dtype=float)
        # Check the rank before touching shape[-1], so a malformed cube gets a clear error.
        if cross_section.ndim != 4:
            raise ValueError(f"Invalid cross_section shape {cross_section.shape} in {path}; expected 4D array.")
        nG = cross_section.shape[-1]
        g_points = np.asarray(root.get("g_points", np.linspace(0.0, 1.0, nG)), dtype=float)
        default_weights = np.ones_like(g_points) / g_points.size if g_points.size > 0 else g_points
        g_weights = np.asarray(root.get("g_weights", default_weights), dtype=float)
    finally:
        # Every array above is materialized as NumPy, so the archive can be released.
        if zip_store is not None:
            zip_store.close()

    # Every 1-D grid must match the corresponding cube axis.
    nT, nP, nW, nG = cross_section.shape
    if temperatures.size != nT:
        raise ValueError(f"Temperature grid length {temperatures.size} does not match cross-section axis {nT} in {path}.")
    if pressures.size != nP:
        raise ValueError(f"Pressure grid length {pressures.size} does not match cross-section axis {nP} in {path}.")
    if wavelengths.size != nW:
        raise ValueError(f"Wavelength grid length {wavelengths.size} does not match cross-section axis {nW} in {path}.")
    if g_points.size != nG:
        raise ValueError(f"g_point array length {g_points.size} does not match cross-section axis {nG} in {path}.")
    if g_weights.size != nG:
        raise ValueError(f"g_weight array length {g_weights.size} does not match cross-section axis {nG} in {path}.")

    # Sort the g-points, keeping the weights and the cube's g-axis aligned.
    g_sort = np.argsort(g_points)
    g_points = g_points[g_sort]
    g_weights = g_weights[g_sort]
    cross_section = cross_section[..., g_sort]
    if np.any(~np.isfinite(g_points)) or np.any(~np.isfinite(g_weights)):
        raise ValueError(f"Invalid c-k table {path}: non-finite g-points or weights.")
    if np.any(np.diff(g_points) <= 0.0):
        raise ValueError(f"Invalid c-k table {path}: g-points are not strictly increasing.")
    if g_points.min() < 0.0 or g_points.max() > 1.0:
        raise ValueError(f"Invalid c-k table {path}: g-points must lie within [0, 1].")

    # Normalize the quadrature weights: RT assumes sum(weights) == 1.
    g_weights = np.clip(g_weights, 0.0, None)
    wsum = float(np.sum(g_weights))
    if not np.isfinite(wsum) or wsum <= 0.0:
        raise ValueError(f"Invalid c-k table {path}: non-positive weight sum {wsum}.")
    g_weights = g_weights / wsum

    if not use_full_grid:
        wl_obs = np.asarray(obs["wl"], dtype=float)
        dwl_obs = np.asarray(obs["dwl"], dtype=float)
        left_edges = wl_obs - dwl_obs
        right_edges = wl_obs + dwl_obs
        # Keep a bin if it falls inside any observation band (edges inclusive).
        mask = np.any(
            (wavelengths[None, :] >= left_edges[:, None]) & (wavelengths[None, :] <= right_edges[:, None]),
            axis=0,
        )
        if not np.any(mask):
            raise ValueError(f"No CK wavelengths for {name} lie within observation bins.")
        wavelengths = wavelengths[mask]
        cross_section = cross_section[:, :, mask, :]
    else:
        print(f"[c-k] Using full wavelength grid for {name}: {wavelengths.size} bins")

    # Mixed precision: float64 grids, float32 cross sections.
    return CKRegistryEntry(
        name=name,
        idx=index,
        pressures=pressures.astype(np.float64),
        temperatures=temperatures.astype(np.float64),
        wavelengths=wavelengths.astype(np.float64),
        g_points=g_points.astype(np.float64),
        g_weights=g_weights.astype(np.float64),
        cross_sections=cross_section.astype(np.float32),
    )
def _load_ck_npz(index: int, spec, path: str, obs: dict, use_full_grid: bool = False) -> CKRegistryEntry:
"""
Load correlated-k opacity tables stored in the custom NPZ format.
Expected NPZ contents:
- molecule: species name (string or bytes; optional)
- pressure: pressure grid in bar (nP,)
- temperature: temperature grid in K (nT,)
- wavelength: wavelength grid in microns (nwl,)
- g_points: g-point locations (ng,)
- g_weights: quadrature weights (ng,)
- cross_section: log10 cross-sections (nT, nP, nwl, ng)
"""
with np.load(path) as data:
cross_section = np.asarray(data["cross_section"], dtype=float)
name_raw = data.get("molecule", getattr(spec, "species", f"ck_{index}"))
pressures = np.asarray(data["pressure"], dtype=float)
temperatures = np.asarray(data["temperature"], dtype=float)
wavelengths = np.asarray(data["wavelength"], dtype=float)
nG = cross_section.shape[-1]
g_points = np.asarray(data.get("g_points", np.linspace(0.0, 1.0, nG)), dtype=float)
default_weights = np.ones_like(g_points) / g_points.size if g_points.size > 0 else g_points
g_weights = np.asarray(data.get("g_weights", default_weights), dtype=float)
if isinstance(name_raw, np.ndarray):
name = name_raw.tolist()
if isinstance(name, list):
name = name[0]
else:
name = name_raw
if isinstance(name, bytes):
name = name.decode("utf-8")
name = str(name)
if cross_section.ndim != 4:
raise ValueError(f"Invalid cross_section shape {cross_section.shape} in {path}; expected 4D array.")
nT, nP, nW, nG = cross_section.shape
if wavelengths.size != nW:
raise ValueError(f"Wavelength grid length {wavelengths.size} does not match cross-section axis {nW} in {path}.")
if g_points.size != nG:
raise ValueError(f"g_point array length {g_points.size} does not match cross-section axis {nG} in {path}.")
if g_weights.size != nG:
raise ValueError(f"g_weight array length {g_weights.size} does not match cross-section axis {nG} in {path}.")
# Ensure g-points are strictly increasing and weights are aligned with the cross_section g-axis.
g_sort = np.argsort(g_points)
g_points = g_points[g_sort]
g_weights = g_weights[g_sort]
cross_section = cross_section[..., g_sort]
if np.any(~np.isfinite(g_points)) or np.any(~np.isfinite(g_weights)):
raise ValueError(f"Invalid c-k table {path}: non-finite g-points or weights.")
if np.any(np.diff(g_points) <= 0.0):
raise ValueError(f"Invalid c-k table {path}: g-points are not strictly increasing.")
if g_points.min() < 0.0 or g_points.max() > 1.0:
raise ValueError(f"Invalid c-k table {path}: g-points must lie within [0, 1].")
# Normalize quadrature weights (RT assumes sum(weights)=1).
g_weights = np.clip(g_weights, 0.0, None)
wsum = float(np.sum(g_weights))
if not np.isfinite(wsum) or wsum <= 0.0:
raise ValueError(f"Invalid c-k table {path}: non-positive weight sum {wsum}.")
g_weights = g_weights / wsum
if not use_full_grid:
wl_obs = np.asarray(obs["wl"], dtype=float)
dwl_obs = np.asarray(obs["dwl"], dtype=float)
left_edges = wl_obs - dwl_obs
right_edges = wl_obs + dwl_obs
mask = np.any(
(wavelengths[None, :] >= left_edges[:, None]) & (wavelengths[None, :] <= right_edges[:, None]),
axis=0,
)
if not np.any(mask):
raise ValueError(f"No CK wavelengths for {name} lie within observation bins.")
wavelengths = wavelengths[mask]
cross_section = cross_section[:, :, mask, :]
else:
print(f"[c-k] Using full wavelength grid for {name}: {wavelengths.size} bins")
# Return a dataclass with NumPy arrays (will be converted to JAX later)
# Mixed precision: float64 for grids, float32 for cross sections to save memory.
return CKRegistryEntry(
name=name,
idx=index,
pressures=pressures.astype(np.float64),
temperatures=temperatures.astype(np.float64),
wavelengths=wavelengths.astype(np.float64),
g_points=g_points.astype(np.float64),
g_weights=g_weights.astype(np.float64),
cross_sections=cross_section.astype(np.float32),
)
# Pad the tables to a rectangle (in dimension) - usually only in T and g as wavelength and pressure grids are the same
# Uses NumPy for preprocessing (CPU-based padding before sending to device)
def _rectangularize_entries(entries: List[CKRegistryEntry]) -> Tuple[CKRegistryEntry, ...]:
# Return if zero c-k table
if not entries:
return ()
# Find the wavelength and pressure grid from the first tables (should be the same across all species)
base_wavelengths = entries[0].wavelengths
base_pressures = entries[0].pressures
base_g_points = entries[0].g_points
base_g_weights = entries[0].g_weights
expected_wavelengths = base_wavelengths.shape[0]
for entry in entries[1:]:
if entry.wavelengths.shape != base_wavelengths.shape or not np.allclose(entry.wavelengths, base_wavelengths):
raise ValueError(f"c-k opacity wavelength grids differ between {entries[0].name} and {entry.name}.")
if entry.pressures.shape != base_pressures.shape or not np.allclose(entry.pressures, base_pressures):
raise ValueError(f"c-k opacity pressure grids differ between {entries[0].name} and {entry.name}.")
if entry.g_points.shape != base_g_points.shape or not np.allclose(entry.g_points, base_g_points):
raise ValueError(f"c-k g-point grids differ between {entries[0].name} and {entry.name}.")
if entry.g_weights.shape != base_g_weights.shape or not np.allclose(entry.g_weights, base_g_weights):
raise ValueError(f"c-k g-weight grids differ between {entries[0].name} and {entry.name}.")
# Find the max number of pressure points
max_pressures = max(entry.pressures.shape[0] for entry in entries)
# Find the max number of temperature points
max_temperatures = max(entry.temperatures.shape[0] for entry in entries)
# Find the max number of g-points
max_g = max(entry.g_points.shape[0] for entry in entries)
# Start a new list for the padded cross section tables and pad the temperature and g arrays
padded_entries: List[CKRegistryEntry] = []
for entry in entries:
# Keep as NumPy arrays for preprocessing
pressures = entry.pressures
temperatures = entry.temperatures
g_points = entry.g_points
g_weights = entry.g_weights
xs = entry.cross_sections
current_temperatures, current_pressures, wavelength_count, current_g = xs.shape
if wavelength_count != expected_wavelengths:
raise ValueError(f"Species {entry.name} has λ grid length {wavelength_count}, expected {expected_wavelengths}.")
if current_pressures != max_pressures:
raise ValueError(f"Species {entry.name} has nP={current_pressures}, expected {max_pressures} for common grid.")
# Pad temperatures (use NumPy padding)
pad_temperatures = max_temperatures - current_temperatures
if pad_temperatures > 0:
temperatures = np.pad(temperatures, (0, pad_temperatures), mode="edge")
xs = np.pad(xs, ((0, pad_temperatures), (0, 0), (0, 0), (0, 0)), mode="edge")
# Pad g-points and g-weights (use NumPy padding)
pad_g = max_g - current_g
if pad_g > 0:
g_points = np.pad(g_points, (0, pad_g), mode="edge")
g_weights = np.pad(g_weights, (0, pad_g), constant_values=0.0) # Pad weights with 0
xs = np.pad(xs, ((0, 0), (0, 0), (0, 0), (0, pad_g)), mode="edge")
padded_entries.append(
CKRegistryEntry(
name=entry.name,
idx=entry.idx,
pressures=pressures,
temperatures=temperatures,
wavelengths=base_wavelengths,
g_points=g_points,
g_weights=g_weights,
cross_sections=xs,
)
)
return tuple(padded_entries)
# Read in and prepare the correlated-k data
[docs]
def load_ck_registry(cfg, obs, lam_master: Optional[np.ndarray] = None, base_dir: Optional[Path] = None):
    """
    Load every configured correlated-k table and publish it as device (JAX) caches.

    Species come from ``cfg.opac.ck`` when it is a list. When ``cfg.opac.ck``
    is ``True``, they come from ``cfg.opac.line``. Each species entry must
    provide ``.path`` and ``.species``. If c-k is disabled, or no species are
    configured, the registry is reset and the function returns.

    Parameters
    ----------
    cfg : configuration object exposing ``cfg.opac.ck``, ``cfg.opac.line`` and
        optionally ``cfg.opac.full_grid``.
    obs : observation dict forwarded to the per-format loaders (``"wl"``, ``"dwl"``).
    lam_master : accepted for interface compatibility; not used here.
    base_dir : directory that relative table paths are resolved against
        (the current working directory when None).

    Raises
    ------
    ValueError
        For unsupported file formats, and for malformed or mutually
        inconsistent tables (raised by the loaders and ``_rectangularize_entries``).

    Every device array is built in local variables first, and the module-level
    caches are assigned only after all of them succeed. An exception part-way
    through (e.g. a failed device transfer) therefore leaves the previously
    loaded registry intact instead of half-replaced.
    """
    global _CK_SPECIES_NAMES, _CK_SIGMA_CACHE, _CK_TEMPERATURE_CACHE, _CK_G_POINTS_CACHE, _CK_G_WEIGHTS_CACHE
    global _CK_WAVELENGTH_CACHE, _CK_PRESSURE_CACHE, _CK_LOG10_TEMPERATURE_CACHE, _CK_LOG10_PRESSURE_CACHE
    global _CK_RUNTIME_SPECIES_ORDER, _CK_G_POINTS_1D_CACHE, _CK_G_WEIGHTS_1D_CACHE

    ck_mode = getattr(cfg.opac, "ck", None)
    if not ck_mode:
        reset_registry()
        return
    # ck == True: species are listed under cfg.opac.line; otherwise ck is the list itself.
    config = getattr(cfg.opac, "line", None) if isinstance(ck_mode, bool) else ck_mode
    if not config or config in ("None", "none"):
        reset_registry()
        return
    use_full_grid = getattr(cfg.opac, "full_grid", False)

    # Read each species' table on the host.
    entries: List[CKRegistryEntry] = []
    for index, spec in enumerate(config):
        path = _resolve_ck_table_path(spec, base_dir)
        print("[c-k] Reading correlated-k xs for", spec.species, "@", str(path))
        entries.append(_read_ck_table(index, spec, path, obs, use_full_grid))

    # Pad to a common temperature length so the tables stack into one array.
    rectangularized_entries = _rectangularize_entries(entries)
    if not rectangularized_entries:
        reset_registry()
        return

    # One host -> device transfer per cache. Mixed precision: float32 for the
    # (large) cross-section cube, float64 requested for grids (interpolation accuracy).
    print(f"[CK] Transferring {len(rectangularized_entries)} species to device...")
    sigma_dev = jnp.asarray(
        np.stack([entry.cross_sections for entry in rectangularized_entries], axis=0),
        dtype=jnp.float32,
    )  # (n_species, nT, nP, nwl, ng)
    temperature_dev = jnp.asarray(
        np.stack([entry.temperatures for entry in rectangularized_entries], axis=0),
        dtype=jnp.float64,
    )  # (n_species, nT)
    g_points_dev = jnp.asarray(
        np.stack([entry.g_points for entry in rectangularized_entries], axis=0),
        dtype=jnp.float64,
    )  # (n_species, ng)
    g_weights_dev = jnp.asarray(
        np.stack([entry.g_weights for entry in rectangularized_entries], axis=0),
        dtype=jnp.float64,
    )  # (n_species, ng)
    wavelength_dev = jnp.asarray(rectangularized_entries[0].wavelengths, dtype=jnp.float64)
    pressure_dev = jnp.asarray(rectangularized_entries[0].pressures, dtype=jnp.float64)
    # Precompute log10 grids used for interpolation.
    log10_pressure_dev = jnp.log10(pressure_dev)
    log10_temperature_dev = jnp.log10(temperature_dev)

    print(f"[CK] Cross section cache: {sigma_dev.shape} (dtype: {sigma_dev.dtype})")
    print(f"[CK] Temperature cache: {temperature_dev.shape} (dtype: {temperature_dev.dtype})")
    print(f"[CK] G-points cache: {g_points_dev.shape} (dtype: {g_points_dev.dtype})")
    print(f"[CK] G-weights cache: {g_weights_dev.shape} (dtype: {g_weights_dev.dtype})")
    print("[CK] Cached log10(P) and log10(T) grids for efficient interpolation")

    sigma_mb = sigma_dev.size * sigma_dev.itemsize / 1024**2
    temp_mb = temperature_dev.size * temperature_dev.itemsize / 1024**2
    g_points_mb = g_points_dev.size * g_points_dev.itemsize / 1024**2
    g_weights_mb = g_weights_dev.size * g_weights_dev.itemsize / 1024**2
    total_mb = sigma_mb + temp_mb + g_points_mb + g_weights_mb
    print(f"[CK] Estimated device memory: {total_mb:.1f} MB (σ: {sigma_mb:.1f} MB, T: {temp_mb:.2f} MB, g: {g_points_mb:.2f} MB, w: {g_weights_mb:.2f} MB)")

    species_names = tuple(entry.name for entry in rectangularized_entries)
    # g grids are identical across species (checked in _rectangularize_entries), so row 0 serves all.
    g_points_1d = g_points_dev[0] if g_points_dev.ndim > 1 else g_points_dev
    g_weights_1d = g_weights_dev[0] if g_weights_dev.ndim > 1 else g_weights_dev

    # Release the host-side tables; the device arrays now hold the data.
    del rectangularized_entries, entries
    print(f"[CK] Freed NumPy temporary arrays from CPU memory")

    # Commit: everything above succeeded, so publish the new registry together.
    _CK_SIGMA_CACHE = sigma_dev
    _CK_TEMPERATURE_CACHE = temperature_dev
    _CK_G_POINTS_CACHE = g_points_dev
    _CK_G_WEIGHTS_CACHE = g_weights_dev
    _CK_WAVELENGTH_CACHE = wavelength_dev
    _CK_PRESSURE_CACHE = pressure_dev
    _CK_LOG10_PRESSURE_CACHE = log10_pressure_dev
    _CK_LOG10_TEMPERATURE_CACHE = log10_temperature_dev
    _CK_SPECIES_NAMES = species_names
    _CK_RUNTIME_SPECIES_ORDER = species_names
    _CK_G_POINTS_1D_CACHE = g_points_1d
    _CK_G_WEIGHTS_1D_CACHE = g_weights_1d
    # Drop any values memoized by the ck_* accessors before this load.
    _clear_cache()


def _resolve_ck_table_path(spec, base_dir: Optional[Path]) -> Path:
    """Return ``spec.path`` with ``~`` expanded; relative paths are anchored at ``base_dir`` (or the CWD) and resolved."""
    path = Path(spec.path).expanduser()
    if path.is_absolute():
        return path
    if base_dir is not None:
        path = Path(base_dir) / path
    return path.resolve()


def _read_ck_table(index: int, spec, path: Path, obs, use_full_grid: bool) -> CKRegistryEntry:
    """Read one correlated-k table, choosing the loader from the file suffix of ``path``."""
    path_str = str(path)
    if path_str.endswith(".npz"):
        return _load_ck_npz(index, spec, path_str, obs, use_full_grid=use_full_grid)
    if path_str.endswith((".h5", ".hdf5")):
        return _load_ck_h5(index, spec, path_str, obs, use_full_grid=use_full_grid)
    if path_str.endswith((".zarr", ".zarr.zip")):
        # If the .zarr directory is missing, fall back to a zipped copy next to it.
        if path_str.endswith(".zarr") and not path.exists():
            zip_fallback = Path(path_str + ".zip")
            if zip_fallback.exists():
                path_str = str(zip_fallback)
                print(f"[c-k] .zarr directory not found; using {zip_fallback.name}")
        return _load_ck_zarr(index, spec, path_str, obs, use_full_grid=use_full_grid)
    raise ValueError(f"Unsupported file format for {path_str}. Expected .npz, .h5, .hdf5, .zarr or .zarr.zip")
### -- lru cached helper functions below --- ###
[docs]
@lru_cache(None)
def ck_species_names() -> Tuple[str, ...]:
    """Return the loaded species names, in the order their tables were stacked.

    Memoized until the registry is reset or reloaded.
    Raises RuntimeError when no correlated-k tables are loaded.
    """
    names = _CK_SPECIES_NAMES
    if names:
        return names
    raise RuntimeError("CK registry empty; call build_opacities() first.")
[docs]
@lru_cache(None)
def ck_runtime_species_order() -> Tuple[str, ...]:
    """Return the species order used at runtime (set alongside the species names on load).

    Raises RuntimeError when the order has not been built yet.
    """
    order = _CK_RUNTIME_SPECIES_ORDER
    if order:
        return order
    raise RuntimeError("CK runtime species order not built; call build_opacities() first.")
[docs]
@lru_cache(None)
def ck_master_wavelength() -> jnp.ndarray:
    """Return the shared wavelength grid in microns, shape (nwl,).

    Raises RuntimeError when no correlated-k tables are loaded.
    """
    grid = _CK_WAVELENGTH_CACHE
    if grid is not None:
        return grid
    raise RuntimeError("CK registry empty; call build_opacities() first.")
[docs]
@lru_cache(None)
def ck_pressure_grid() -> jnp.ndarray:
    """Return the shared pressure grid in bar, shape (nP,).

    Raises RuntimeError when no correlated-k tables are loaded.
    """
    grid = _CK_PRESSURE_CACHE
    if grid is not None:
        return grid
    raise RuntimeError("CK registry empty; call build_opacities() first.")
[docs]
@lru_cache(None)
def ck_temperature_grids() -> jnp.ndarray:
    """Return the per-species temperature grids in K, shape (n_species, nT).

    Shorter grids are edge-padded to the common length.
    Raises RuntimeError when the grids have not been built.
    """
    grids = _CK_TEMPERATURE_CACHE
    if grids is not None:
        return grids
    raise RuntimeError("c-k temperature grids not built; call build_opacities() first.")
[docs]
@lru_cache(None)
def ck_temperature_grid() -> jnp.ndarray:
    """Return the first species' temperature grid (row 0 of ``ck_temperature_grids()``).

    Propagates the RuntimeError from ``ck_temperature_grids`` when nothing is loaded.
    """
    all_grids = ck_temperature_grids()
    first_species_grid = all_grids[0]
    return first_species_grid
[docs]
@lru_cache(None)
def ck_sigma_cube() -> jnp.ndarray:
    """Return the stacked log10 cross-section cube, shape (n_species, nT, nP, nwl, ng).

    Raises RuntimeError when the cube has not been built.
    """
    cube = _CK_SIGMA_CACHE
    if cube is not None:
        return cube
    raise RuntimeError("c-k σ cube not built; call build_opacities() first.")
[docs]
@lru_cache(None)
def ck_g_points() -> jnp.ndarray:
    """Return the per-species g-point arrays, shape (n_species, ng).

    Raises RuntimeError when they have not been built.
    """
    points = _CK_G_POINTS_CACHE
    if points is not None:
        return points
    raise RuntimeError("c-k g-points not built; call build_opacities() first.")
[docs]
@lru_cache(None)
def ck_g_weights() -> jnp.ndarray:
    """Return the per-species normalized quadrature weights, shape (n_species, ng).

    Raises RuntimeError when they have not been built.
    """
    weights = _CK_G_WEIGHTS_CACHE
    if weights is not None:
        return weights
    raise RuntimeError("c-k g-weights not built; call build_opacities() first.")
[docs]
@lru_cache(None)
def ck_log10_pressure_grid() -> jnp.ndarray:
    """Return log10 of the shared pressure grid, precomputed at load time.

    Raises RuntimeError when it has not been built.
    """
    log_grid = _CK_LOG10_PRESSURE_CACHE
    if log_grid is not None:
        return log_grid
    raise RuntimeError("c-k log10(P) grid not built; call build_opacities() first.")
[docs]
@lru_cache(None)
def ck_log10_temperature_grids() -> jnp.ndarray:
    """Return log10 of the per-species temperature grids, shape (n_species, nT).

    Raises RuntimeError when they have not been built.
    """
    log_grids = _CK_LOG10_TEMPERATURE_CACHE
    if log_grids is not None:
        return log_grids
    raise RuntimeError("c-k log10(T) grids not built; call build_opacities() first.")
[docs]
@lru_cache(None)
def ck_g_points_1d() -> jnp.ndarray:
    """Return the shared 1-D g-point grid (row 0 of the per-species g-points).

    Raises RuntimeError when it has not been built.
    """
    points_1d = _CK_G_POINTS_1D_CACHE
    if points_1d is not None:
        return points_1d
    raise RuntimeError("c-k 1D g-points cache not built; call build_opacities() first.")
[docs]
@lru_cache(None)
def ck_g_weights_1d() -> jnp.ndarray:
    """Return the shared 1-D quadrature weights (row 0 of the per-species weights).

    Raises RuntimeError when they have not been built.
    """
    weights_1d = _CK_G_WEIGHTS_1D_CACHE
    if weights_1d is not None:
        return weights_1d
    raise RuntimeError("c-k 1D g-weights cache not built; call build_opacities() first.")