"""
registry_line.py
================
"""
from __future__ import annotations
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Tuple, Optional
from pathlib import Path
import jax.numpy as jnp
import numpy as np
import h5py
import zarr
__all__ = [
"LineRegistryEntry",
"reset_registry",
"has_line_data",
"load_line_registry",
"line_species_names",
"line_master_wavelength",
"line_pressure_grid",
"line_temperature_grid",
"line_temperature_grids",
"line_sigma_cube",
"line_log10_pressure_grid",
"line_log10_temperature_grids",
"line_runtime_species_order",
]
# Dataclass for each of the line opacity tables
# Note: During preprocessing, all arrays are NumPy (CPU)
# They get converted to JAX (device) only at the final cache creation step
# Mixed precision: grids stay float64 (better accuracy); cross sections are preprocessed in float64 and only downcast to float32 when transferred to the device (memory savings)
[docs]
@dataclass(frozen=True)
class LineRegistryEntry:
    """
    Immutable host-side (NumPy) record holding one line-opacity table.

    Instances are produced by the ``_load_line_*`` loaders in this module and
    consumed by the registry build step, which stacks them and transfers the
    result to JAX arrays. The loaders in this file store every array as
    float64; the float32 downcast of the cross sections happens later, at the
    device transfer.
    """
    name: str  # species name as read from the opacity file
    idx: int  # index handed to the loader (position of the species in the configured list)
    pressures: np.ndarray  # pressure grid in bar, shape (nP,), float64 as produced by the loaders
    temperatures: np.ndarray  # temperature grid in K, shape (nT,), float64 as produced by the loaders
    wavelengths: np.ndarray  # target wavelength grid in microns, shape (nwl,), float64
    cross_sections: np.ndarray  # log10 cross sections, shape (nT, nP, nwl); float64 here, cast to float32 only on device transfer
# Global registries and caches for forward model.
# All of these are (re)assigned by load_line_registry() and cleared by reset_registry();
# the lru-cached accessor functions of this module hand them out to callers.
_LINE_SPECIES_NAMES: Tuple[str, ...] = ()  # Lightweight: only species names (few bytes)
_LINE_SIGMA_CACHE: jnp.ndarray | None = None  # stacked log10 cross sections, (n_species, nT, nP, nwl), requested as float32
_LINE_TEMPERATURE_CACHE: jnp.ndarray | None = None  # stacked (padded) temperature grids, (n_species, nT), requested as float64
_LINE_WAVELENGTH_CACHE: jnp.ndarray | None = None  # shared wavelength grid taken from the first entry, (nwl,)
_LINE_PRESSURE_CACHE: jnp.ndarray | None = None  # shared pressure grid taken from the first entry, (nP,)
_LINE_LOG10_TEMPERATURE_CACHE: jnp.ndarray | None = None  # log10 of _LINE_TEMPERATURE_CACHE
_LINE_LOG10_PRESSURE_CACHE: jnp.ndarray | None = None  # log10 of _LINE_PRESSURE_CACHE
# Clear all the cache entries
def _clear_cache():
    """
    Drop the memoized results of every lru-cached accessor in this module.

    Must be called whenever the module-level ``_LINE_*`` globals are
    reassigned, otherwise the accessors would keep returning stale values.
    """
    cached_accessors = (
        line_species_names,
        line_runtime_species_order,
        line_master_wavelength,
        line_pressure_grid,
        line_temperature_grid,
        line_temperature_grids,
        line_sigma_cube,
        line_log10_pressure_grid,
        line_log10_temperature_grids,
    )
    for accessor in cached_accessors:
        accessor.cache_clear()
# Reset all the global registries
[docs]
def reset_registry() -> None:
    """
    Return the line registry to its empty state.

    Clears the species-name tuple, sets every array cache to ``None`` and
    invalidates the lru-cached accessors so that they raise again until new
    data is loaded.
    """
    global _LINE_SPECIES_NAMES
    global _LINE_SIGMA_CACHE, _LINE_TEMPERATURE_CACHE, _LINE_LOG10_TEMPERATURE_CACHE
    global _LINE_WAVELENGTH_CACHE, _LINE_PRESSURE_CACHE, _LINE_LOG10_PRESSURE_CACHE

    _LINE_SPECIES_NAMES = ()
    # One chained assignment: every array cache goes back to "not built".
    _LINE_SIGMA_CACHE = _LINE_TEMPERATURE_CACHE = _LINE_WAVELENGTH_CACHE = None
    _LINE_PRESSURE_CACHE = _LINE_LOG10_TEMPERATURE_CACHE = _LINE_LOG10_PRESSURE_CACHE = None

    _clear_cache()
# Check if the registries are set or not
[docs]
def has_line_data() -> bool:
    """Return ``True`` once the cross-section cache has been populated."""
    sigma_cache = _LINE_SIGMA_CACHE
    return sigma_cache is not None
# Function to load TauREx HDF5 opacity data
def _load_line_h5(index: int, path: str, target_wavelengths: np.ndarray) -> LineRegistryEntry:
    """
    Load a TauREx HDF5 opacity table and resample it onto ``target_wavelengths``.

    TauREx format (datasets read from the file):
    - mol_name: molecule name (string)
    - p: pressure grid in bar (nP,)
    - t: temperature grid in K (nT,)
    - bin_edges: wavenumber bin edges in cm^-1 (nwl+1,)
    - xsecarr: cross sections in cm^2/molecule (nP, nT, nwl)

    Parameters:
    - index: registry index stored on the returned entry
    - path: path of the HDF5 file
    - target_wavelengths: 1D wavelength grid in microns to interpolate onto

    Returns data in registry format:
    - pressures in bar (nP,)
    - temperatures in K (nT,)
    - wavelengths in microns (target_wavelengths,)
    - cross_sections in log10(cm^2) (nT, nP, target_wavelengths)

    Raises:
    - ValueError if the target grid is not 1D, if ``xsecarr`` is not 3D, or if
      its first two axes do not match the pressure/temperature grids.
    """
    # Coerce and validate the target grid up front, exactly like the NPZ and
    # Zarr loaders do, so list/tuple inputs work and bad shapes fail early.
    target_wavelengths = np.asarray(target_wavelengths, dtype=float)
    if target_wavelengths.ndim != 1:
        raise ValueError("Target wavelength grid must be 1D.")

    with h5py.File(path, 'r') as f:
        # Read molecule name
        name = f["mol_name"][0]
        if isinstance(name, bytes):
            name = name.decode('utf-8')
        name = str(name)
        # Read grids (``[:]`` materializes them, so nothing below needs the file)
        pressures = np.asarray(f["p"][:], dtype=float)  # bar
        temperatures = np.asarray(f["t"][:], dtype=float)  # K
        bin_edges = np.asarray(f["bin_edges"][:], dtype=float)  # cm^-1
        native_xs = np.asarray(f["xsecarr"][:], dtype=float)  # (nP, nT, nwl) cm^2/molecule

    # Dimensions
    n_pressures = pressures.size
    n_temperatures = temperatures.size

    # Without these checks a too-large xsecarr would be silently truncated by
    # the fill loop below and a too-small one would die with a bare IndexError.
    if native_xs.ndim != 3:
        raise ValueError(f"Invalid xsecarr shape {native_xs.shape} in {path}; expected 3D array.")
    if native_xs.shape[:2] != (n_pressures, n_temperatures):
        raise ValueError(
            f"xsecarr shape {native_xs.shape} in {path} does not match the "
            f"pressure/temperature grids (nP={n_pressures}, nT={n_temperatures})."
        )

    # Convert wavenumber bin edges to wavelength bin centers
    # λ[μm] = 10000 / ν[cm^-1]
    # NOTE(review): this assumes bin_edges has nwl+1 entries as documented above.
    # If a file stores only nwl values, the centers below have nwl-1 entries and
    # the last xsecarr column is dropped by the fancy index - confirm against real files.
    wavenumber_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    native_wavelengths = 10000.0 / wavenumber_centers  # μm

    # np.interp needs an increasing abscissa, so sort by wavelength
    sort_idx = np.argsort(native_wavelengths)
    native_wavelengths = native_wavelengths[sort_idx]

    # Transpose from (nP, nT, nwl) to (nT, nP, nwl) and apply wavelength sorting
    native_xs_transposed = np.transpose(native_xs, (1, 0, 2))[:, :, sort_idx]

    # Convert to log10; clamp first so zeros do not produce -inf / warnings
    min_xs = 1e-99  # corresponds to log10 = -99
    native_xs_log = np.log10(np.maximum(native_xs_transposed, min_xs))

    # Dimensions of the master wavelength
    wavelength_count = target_wavelengths.size

    # Interpolate to target wavelength grid.
    # Keep NumPy-side preprocessing in float64; downcast happens at JAX cache transfer.
    xs_interp = np.empty((n_temperatures, n_pressures, wavelength_count), dtype=np.float64)
    for iT in range(n_temperatures):
        for iP in range(n_pressures):
            xs_interp[iT, iP, :] = np.interp(
                target_wavelengths,
                native_wavelengths,
                native_xs_log[iT, iP, :],
                left=-99.0,   # outside the table: same floor as the log10 clamp
                right=-99.0,
            )

    # Return a dataclass with NumPy arrays (will be converted to JAX later).
    # Keep NumPy-side arrays float64 for preprocessing consistency.
    return LineRegistryEntry(
        name=name,
        idx=index,
        pressures=pressures.astype(np.float64),
        temperatures=temperatures.astype(np.float64),
        wavelengths=target_wavelengths.astype(np.float64),
        cross_sections=xs_interp.astype(np.float64),
    )
def _load_line_zarr(index: int, path: str, target_wavelengths: np.ndarray) -> LineRegistryEntry:
    """
    Load opacity tables stored in the custom Zarr format.

    Expected Zarr contents:
    - attrs["molecule"]: species name (string); falls back to ``line_<index>``
    - temperature: temperature grid in K (nT,)
    - pressure: pressure grid in bar (nP,)
    - wavelength: wavelength grid in microns (nwl,)
    - cross_section: log10 cross-sections (nT, nP, nwl)

    Parameters:
    - index: registry index stored on the returned entry
    - path: Zarr directory, or a ``.zip`` archive (opened through ``ZipStore``)
    - target_wavelengths: 1D wavelength grid in microns to interpolate onto

    Returns a ``LineRegistryEntry`` with float64 NumPy arrays; cross sections
    have shape (nT, nP, len(target_wavelengths)).

    Raises:
    - ValueError if ``cross_section`` is not 3D, if the wavelength grid length
      does not match its last axis, or if the target grid is not 1D.
    """
    zip_store = None
    if path.endswith(".zip"):
        from zarr.storage import ZipStore
        zip_store = ZipStore(path, mode="r")
        root = zarr.open_group(store=zip_store, mode="r")
    else:
        root = zarr.open_group(path, mode="r")

    try:
        name = str(root.attrs.get("molecule", f"line_{index}"))
        # ``[:]`` reads each array fully into memory, so the store is not
        # needed once this block has finished.
        temperatures = np.asarray(root["temperature"][:], dtype=float)
        pressures = np.asarray(root["pressure"][:], dtype=float)
        native_wavelengths = np.asarray(root["wavelength"][:], dtype=float)
        xs = np.asarray(root["cross_section"][:], dtype=float)
    finally:
        # The zip archive handle was previously never released.
        if zip_store is not None:
            zip_store.close()

    if xs.ndim != 3:
        raise ValueError(f"Invalid cross_section shape {xs.shape} in {path}; expected 3D array.")
    n_temperatures, n_pressures, native_wl_count = xs.shape
    if native_wavelengths.size != native_wl_count:
        raise ValueError(
            f"Wavelength grid length {native_wavelengths.size} does not match cross-section axis {native_wl_count} in {path}."
        )

    target_wavelengths = np.asarray(target_wavelengths, dtype=float)
    if target_wavelengths.ndim != 1:
        raise ValueError("Target wavelength grid must be 1D.")

    if native_wavelengths.shape == target_wavelengths.shape and np.allclose(native_wavelengths, target_wavelengths):
        # Table is already on the target grid: skip the interpolation entirely.
        xs_interp = xs
    else:
        xs_interp = np.empty((n_temperatures, n_pressures, target_wavelengths.size), dtype=np.float64)
        for iT in range(n_temperatures):
            for iP in range(n_pressures):
                xs_interp[iT, iP, :] = np.interp(
                    target_wavelengths,
                    native_wavelengths,
                    xs[iT, iP, :],
                    left=-99.0,   # log10 floor outside the tabulated range
                    right=-99.0,
                )

    return LineRegistryEntry(
        name=name,
        idx=index,
        pressures=pressures.astype(np.float64),
        temperatures=temperatures.astype(np.float64),
        wavelengths=target_wavelengths.astype(np.float64),
        cross_sections=xs_interp.astype(np.float64),
    )
def _load_line_npz(index: int, path: str, target_wavelengths: np.ndarray) -> LineRegistryEntry:
    """
    Read one opacity table from the custom NPZ format (Gen_OS_table_R_zarr.py output).

    Arrays expected in the archive:
    - molecule: species name (string or bytes)
    - temperature: temperature grid in K (nT,)
    - pressure: pressure grid in bar (nP,)
    - wavelength: wavelength grid in microns (nwl,)
    - cross_section: log10 cross-sections (nT, nP, nwl)

    The cross sections are linearly interpolated in wavelength onto
    ``target_wavelengths`` unless the native grid already matches it; points
    outside the native range are set to -99. Returns a ``LineRegistryEntry``
    of float64 NumPy arrays. Raises ``ValueError`` for a non-3D cross-section
    array, a wavelength/cross-section length mismatch, or a non-1D target grid.
    """
    with np.load(path) as archive:
        raw_name = archive["molecule"]
        temperatures = np.asarray(archive["temperature"], dtype=float)
        pressures = np.asarray(archive["pressure"], dtype=float)
        native_wavelengths = np.asarray(archive["wavelength"], dtype=float)
        xs = np.asarray(archive["cross_section"], dtype=float)

    # Unwrap the stored name: array -> Python object -> first element -> text.
    name = raw_name
    if isinstance(raw_name, np.ndarray):
        name = raw_name.tolist()
        if isinstance(name, list):
            name = name[0]
    if isinstance(name, bytes):
        name = name.decode("utf-8")
    name = str(name)

    if xs.ndim != 3:
        raise ValueError(f"Invalid cross_section shape {xs.shape} in {path}; expected 3D array.")
    n_temperatures, n_pressures, native_wl_count = xs.shape
    if native_wavelengths.size != native_wl_count:
        raise ValueError(
            f"Wavelength grid length {native_wavelengths.size} does not match cross-section axis {native_wl_count} in {path}."
        )

    target_wavelengths = np.asarray(target_wavelengths, dtype=float)
    if target_wavelengths.ndim != 1:
        raise ValueError("Target wavelength grid must be 1D.")

    already_on_target = (
        native_wavelengths.shape == target_wavelengths.shape
        and np.allclose(native_wavelengths, target_wavelengths)
    )
    if already_on_target:
        resampled = xs
    else:
        # Preprocessing stays float64 on the NumPy side; the downcast happens
        # when the caches are moved to JAX.
        resampled = np.empty((n_temperatures, n_pressures, target_wavelengths.size), dtype=np.float64)
        # ndindex walks (iT, iP) in the same row-major order as two nested loops.
        for cell in np.ndindex(n_temperatures, n_pressures):
            resampled[cell] = np.interp(
                target_wavelengths,
                native_wavelengths,
                xs[cell],
                left=-99.0,
                right=-99.0,
            )

    # Host-side arrays only; conversion to JAX happens later.
    return LineRegistryEntry(
        name=name,
        idx=index,
        pressures=pressures.astype(np.float64),
        temperatures=temperatures.astype(np.float64),
        wavelengths=target_wavelengths.astype(np.float64),
        cross_sections=resampled.astype(np.float64),
    )
# Pad the tables to a rectangle (in dimension) - usually only in T as wavelength and pressure grids are the same lengths
# Uses NumPy for preprocessing (CPU-based padding before sending to device)
def _rectangularize_entries(entries: List[LineRegistryEntry]) -> Tuple[LineRegistryEntry, ...]:
    """
    Pad all entries to a common temperature length so they can be stacked.

    Every entry must share the wavelength and pressure grids of the first
    entry. Entries with fewer temperature points than the longest one are
    padded at the end of the temperature axis by repeating their last
    temperature and last cross-section slab (``mode="edge"``).

    Parameters:
    - entries: loaded tables, cross sections shaped (nT, nP, nwl)

    Returns a tuple of new ``LineRegistryEntry`` objects (empty tuple for
    empty input), all with identical array shapes.

    Raises:
    - ValueError if wavelength or pressure grids differ between entries, or if
      an entry's cross-section axes do not match its own grids.
    """
    # Return if zero OS table
    if not entries:
        return ()

    # The first table defines the reference wavelength and pressure grids
    base_wavelengths = entries[0].wavelengths
    base_pressures = entries[0].pressures
    expected_wavelengths = base_wavelengths.shape[0]
    for entry in entries[1:]:
        if entry.wavelengths.shape != base_wavelengths.shape or not np.allclose(entry.wavelengths, base_wavelengths):
            raise ValueError(f"Line opacity wavelength grids differ between {entries[0].name} and {entry.name}.")
        if entry.pressures.shape != base_pressures.shape or not np.allclose(entry.pressures, base_pressures):
            raise ValueError(f"Line opacity pressure grids differ between {entries[0].name} and {entry.name}.")

    # Find the max number of pressure points
    max_pressures = max(entry.pressures.shape[0] for entry in entries)
    # Find the max number of temperature points
    max_temperatures = max(entry.temperatures.shape[0] for entry in entries)

    # Build the padded copies
    padded_entries: List[LineRegistryEntry] = []
    for entry in entries:
        # Keep as NumPy arrays for preprocessing
        pressures = entry.pressures
        temperatures = entry.temperatures
        xs = entry.cross_sections
        current_temperatures, current_pressures, wavelength_count = xs.shape
        if wavelength_count != expected_wavelengths:
            raise ValueError(f"Species {entry.name} has λ grid length {wavelength_count}, expected {expected_wavelengths}.")
        if current_pressures != max_pressures:
            raise ValueError(f"Species {entry.name} has nP={current_pressures}, expected {max_pressures} for common grid.")
        # The pad width below is derived from the cross-section T axis but is
        # applied to the temperature grid as well; if the two disagree the
        # padded arrays end up with inconsistent lengths, so reject that here.
        if temperatures.shape[0] != current_temperatures:
            raise ValueError(
                f"Species {entry.name} has nT={current_temperatures} in its cross sections "
                f"but {temperatures.shape[0]} temperature grid points."
            )
        pad_temperatures = max_temperatures - current_temperatures
        if pad_temperatures > 0:
            # Use NumPy padding (CPU-based)
            temperatures = np.pad(temperatures, (0, pad_temperatures), mode="edge")
            xs = np.pad(xs, ((0, pad_temperatures), (0, 0), (0, 0)), mode="edge")
        padded_entries.append(
            LineRegistryEntry(
                name=entry.name,
                idx=entry.idx,
                pressures=pressures,
                temperatures=temperatures,
                wavelengths=base_wavelengths,
                cross_sections=xs,
            )
        )
    return tuple(padded_entries)
# Read in and prepare the line data
[docs]
def load_line_registry(cfg, obs, lam_master: Optional[np.ndarray] = None, base_dir: Optional[Path] = None):
    """
    Load every configured line-opacity table and publish it to the module caches.

    Parameters:
    - cfg: configuration object; ``cfg.opac.line`` is iterated and each item
      must expose ``path`` and ``species``
    - obs: mapping whose ``"wl"`` entry is used as the target wavelength grid
      when ``lam_master`` is ``None``
    - lam_master: optional target wavelength grid that overrides ``obs["wl"]``
    - base_dir: optional directory that relative table paths are joined to

    Returns ``None``. On success the ``_LINE_*`` globals hold JAX arrays and
    the lru-cached accessors are invalidated. When no line opacity is
    configured (or nothing was loaded) the registry is reset instead.

    Raises ``ValueError`` for a table path with an unsupported extension.
    """
    global _LINE_SPECIES_NAMES, _LINE_SIGMA_CACHE, _LINE_TEMPERATURE_CACHE
    global _LINE_WAVELENGTH_CACHE, _LINE_PRESSURE_CACHE
    global _LINE_LOG10_TEMPERATURE_CACHE, _LINE_LOG10_PRESSURE_CACHE

    # No line opacity configured: make sure the registry is empty and stop.
    if not getattr(cfg.opac, "line", None):
        reset_registry()
        return

    # Interpolation target: explicit master grid if given, else the observed wavelengths.
    if lam_master is None:
        target_grid = np.asarray(obs["wl"], dtype=float)
    else:
        target_grid = np.asarray(lam_master, dtype=float)

    loaded: List[LineRegistryEntry] = []
    for index, spec in enumerate(cfg.opac.line):
        table_path = Path(spec.path).expanduser()
        if not table_path.is_absolute():
            if base_dir is None:
                table_path = table_path.resolve()
            else:
                table_path = (Path(base_dir) / table_path).resolve()
        path_str = str(table_path)
        print("[Line] Reading line xs for", spec.species, "@", path_str)

        # Pick the loader from the file extension.
        if path_str.endswith(".npz"):
            loader = _load_line_npz
        elif path_str.endswith((".h5", ".hdf5")):
            loader = _load_line_h5
        elif path_str.endswith((".zarr", ".zarr.zip")):
            loader = _load_line_zarr
            # A missing ``.zarr`` directory may exist as a zipped store next to it.
            if path_str.endswith(".zarr") and not table_path.exists():
                zipped = Path(path_str + ".zip")
                if zipped.exists():
                    path_str = str(zipped)
                    print(f"[Line] .zarr directory not found; using {zipped.name}")
        else:
            raise ValueError(f"Unsupported file format for {path_str}. Expected .npz, .h5, .hdf5, .zarr or .zarr.zip")
        loaded.append(loader(index, path_str, target_grid))

    # Pad along temperature so every species has the same array shape (needed for stacking).
    padded = _rectangularize_entries(loaded)
    if not padded:
        reset_registry()
        return

    # Everything above ran in NumPy on the host. From here on the final arrays
    # are handed to JAX once:
    #   - grids (P, T, wavelength) are requested as float64
    #   - cross sections are requested as float32 to halve their footprint
    print(f"[Line] Transferring {len(padded)} species to device...")

    # (n_species, nT, nP, nwl); the host copy is float64, the cast happens in jnp.asarray
    sigma_host = np.stack([item.cross_sections for item in padded], axis=0)
    _LINE_SIGMA_CACHE = jnp.asarray(sigma_host, dtype=jnp.float32)

    # (n_species, nT)
    temperature_host = np.stack([item.temperatures for item in padded], axis=0)
    _LINE_TEMPERATURE_CACHE = jnp.asarray(temperature_host, dtype=jnp.float64)

    # Wavelength and pressure grids are identical for all species after rectangularization
    _LINE_WAVELENGTH_CACHE = jnp.asarray(padded[0].wavelengths, dtype=jnp.float64)
    _LINE_PRESSURE_CACHE = jnp.asarray(padded[0].pressures, dtype=jnp.float64)

    # log10 grids are computed once here rather than on every lookup
    _LINE_LOG10_PRESSURE_CACHE = jnp.log10(_LINE_PRESSURE_CACHE)
    _LINE_LOG10_TEMPERATURE_CACHE = jnp.log10(_LINE_TEMPERATURE_CACHE)

    print(f"[Line] Cross section cache: {_LINE_SIGMA_CACHE.shape} (dtype: {_LINE_SIGMA_CACHE.dtype})")
    print(f"[Line] Temperature cache: {_LINE_TEMPERATURE_CACHE.shape} (dtype: {_LINE_TEMPERATURE_CACHE.dtype})")
    print(f"[Line] Cached log10(P) and log10(T) grids for efficient interpolation")

    # Memory estimate (the log10 caches are not included)
    megabyte = 1024**2
    sigma_mb = _LINE_SIGMA_CACHE.size * _LINE_SIGMA_CACHE.itemsize / megabyte
    temp_mb = _LINE_TEMPERATURE_CACHE.size * _LINE_TEMPERATURE_CACHE.itemsize / megabyte
    wl_mb = _LINE_WAVELENGTH_CACHE.size * _LINE_WAVELENGTH_CACHE.itemsize / megabyte
    p_mb = _LINE_PRESSURE_CACHE.size * _LINE_PRESSURE_CACHE.itemsize / megabyte
    total_mb = sigma_mb + temp_mb + wl_mb + p_mb
    print(f"[Line] Estimated device memory: {total_mb:.1f} MB (σ: {sigma_mb:.1f} MB, T: {temp_mb:.2f} MB, λ: {wl_mb:.2f} MB, P: {p_mb:.2f} MB)")

    # Only the names are kept on the Python side
    _LINE_SPECIES_NAMES = tuple(item.name for item in padded)

    # Drop the host-side references; the JAX caches now own the data
    del padded, loaded, sigma_host, temperature_host
    print(f"[Line] Freed NumPy temporary arrays from CPU memory")

    # Accessors may have memoized the previous registry contents
    _clear_cache()
### -- lru cached helper functions below --- ###
[docs]
@lru_cache(maxsize=None)
def line_species_names() -> Tuple[str, ...]:
    """
    Return the names of the loaded line species, in registry order.

    The result is memoized until the accessor's cache is cleared.
    Raises ``RuntimeError`` when no species are registered.
    """
    names = _LINE_SPECIES_NAMES
    if names:
        return names
    raise RuntimeError("Line registry empty; call build_opacities() first.")
[docs]
@lru_cache(maxsize=None)
def line_runtime_species_order() -> Tuple[str, ...]:
    """
    Return the species order used at runtime.

    This is the same tuple as ``line_species_names()`` and raises the same
    ``RuntimeError`` when the registry is empty.
    """
    species_order = line_species_names()
    return species_order
[docs]
@lru_cache(maxsize=None)
def line_master_wavelength() -> jnp.ndarray:
    """
    Return the wavelength grid shared by all loaded line species.

    Raises ``RuntimeError`` when the wavelength cache has not been built.
    """
    grid = _LINE_WAVELENGTH_CACHE
    if grid is not None:
        return grid
    raise RuntimeError("Line registry empty; call build_opacities() first.")
[docs]
@lru_cache(maxsize=None)
def line_pressure_grid() -> jnp.ndarray:
    """
    Return the pressure grid shared by all loaded line species.

    Raises ``RuntimeError`` when the pressure cache has not been built.
    """
    grid = _LINE_PRESSURE_CACHE
    if grid is not None:
        return grid
    raise RuntimeError("Line registry empty; call build_opacities() first.")
[docs]
@lru_cache(maxsize=None)
def line_temperature_grids() -> jnp.ndarray:
    """
    Return the stacked per-species temperature grids, shape (n_species, nT).

    Raises ``RuntimeError`` when the temperature cache has not been built.
    """
    grids = _LINE_TEMPERATURE_CACHE
    if grids is not None:
        return grids
    raise RuntimeError("Line temperature grids not built; call build_opacities() first.")
[docs]
@lru_cache(maxsize=None)
def line_temperature_grid() -> jnp.ndarray:
    """
    Return the temperature grid of the first loaded species.

    Takes row 0 of ``line_temperature_grids()`` and propagates its
    ``RuntimeError`` when the grids have not been built.
    """
    all_grids = line_temperature_grids()
    return all_grids[0]
[docs]
@lru_cache(maxsize=None)
def line_sigma_cube() -> jnp.ndarray:
    """
    Return the stacked cross-section array, shape (n_species, nT, nP, nwl).

    Raises ``RuntimeError`` when the cross-section cache has not been built.
    """
    cube = _LINE_SIGMA_CACHE
    if cube is not None:
        return cube
    raise RuntimeError("Line σ cube not built; call build_opacities() first.")
[docs]
@lru_cache(maxsize=None)
def line_log10_pressure_grid() -> jnp.ndarray:
    """
    Return the precomputed log10 of the shared pressure grid.

    Raises ``RuntimeError`` when the log10(P) cache has not been built.
    """
    log_grid = _LINE_LOG10_PRESSURE_CACHE
    if log_grid is not None:
        return log_grid
    raise RuntimeError("Line log10(P) grid not built; call build_opacities() first.")
[docs]
@lru_cache(maxsize=None)
def line_log10_temperature_grids() -> jnp.ndarray:
    """
    Return the precomputed log10 of the stacked per-species temperature grids.

    Raises ``RuntimeError`` when the log10(T) cache has not been built.
    """
    log_grids = _LINE_LOG10_TEMPERATURE_CACHE
    if log_grids is not None:
        return log_grids
    raise RuntimeError("Line log10(T) grids not built; call build_opacities() first.")