"""
registry_ray.py
===============
"""
from __future__ import annotations
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Tuple, Optional
import numpy as np
import jax.numpy as jnp
# Public API: registry lifecycle helpers followed by the cached accessors
# that expose the device-side Rayleigh tables.
__all__ = [
    # registry entry type and lifecycle
    "RayRegistryEntry",
    "reset_registry",
    "has_ray_data",
    "load_ray_registry",
    # cached table accessors
    "ray_species_names",
    "ray_master_wavelength",
    "ray_sigma_table",
    "ray_nm1_table",
    "ray_nd_ref",
    "ray_runtime_species_order",
    "ray_sigma_linear_table",
    "ray_refractivity_coeff_table",
    "ray_pick_arrays",
]
# Dataclass holding the Rayleigh cross-section data for a single species.
# Note: during preprocessing all arrays are NumPy (CPU); they are converted
# to JAX (device) arrays only at the final cache-creation step.
# Entries hold float64 wavelengths and float64 log10 cross sections; the
# stacked device-side cross-section table is stored as float32.
[docs]
@dataclass(frozen=True)
class RayRegistryEntry:
    """Immutable per-species record built while loading the Rayleigh registry.

    All array fields are host-side NumPy arrays; the JAX device tables are
    assembled from a stack of these entries afterwards.
    """

    # Species label exactly as it appears in the configuration.
    name: str
    # Position of the species in the configured Rayleigh species list.
    idx: int
    # Wavelength grid (microns, float64); every entry shares the same grid.
    wavelengths: np.ndarray
    # log10 of the Rayleigh cross section sampled on ``wavelengths`` (float64).
    cross_sections: np.ndarray
# Module-level Rayleigh state. load_ray_registry() fills these in and
# reset_registry() returns them to the empty values assigned here.
_RAY_ENTRIES: Tuple[RayRegistryEntry, ...] = ()        # one entry per configured species
_RAY_SPECIES_NAMES: Tuple[str, ...] = ()               # entry names, in registry order
_RAY_SIGMA_CACHE: Optional[jnp.ndarray] = None         # log10 sigma, (n_species, n_wl), float32
_RAY_WAVELENGTH_CACHE: Optional[jnp.ndarray] = None    # shared wavelength grid
_RAY_NM1_CACHE: Optional[jnp.ndarray] = None           # refractivity n - 1, (n_species, n_wl), float32
_RAY_NDREF_CACHE: Optional[jnp.ndarray] = None         # reference number density per species
_RAY_SIGMA_LINEAR_CACHE: Optional[jnp.ndarray] = None  # sigma in linear units
_RAY_REFRACTIVITY_COEFF_CACHE: Optional[jnp.ndarray] = None  # (n - 1) / nd_ref per species
# ---------------------------------------------------------------------------
# Physical constants (CGS) shared by the per-species formulas
# ---------------------------------------------------------------------------
PI = np.pi
C_LIGHT = 2.99792458e10        # speed of light [cm s^-1]
SIGMA_T = 6.6524587321e-25     # Thomson (electron) cross section [cm^2]

# Atomic-hydrogen fit (_sigma_H): Lyman-alpha sits at 0.75 * W_L in angular frequency.
WL_LY_CM = 121.567e-7          # Lyman-alpha wavelength [cm]
F_LY = C_LIGHT / WL_LY_CM      # Lyman-alpha frequency [Hz]
W_L = (2.0 * PI * F_LY) / 0.75
# Coefficients of the even power series used below 0.6 * W_L.
CP = np.array((
    1.26537, 3.73766, 8.8127, 19.1515, 39.919,
    81.1018, 161.896, 319.001, 622.229, 1203.82,
))

# Reference number densities [cm^-3] used to normalise (n - 1) in the
# Lorentz-Lorenz cross-section formula; assumed to match the conditions of
# each species' refractive-index fit (TODO confirm per species).
N_STP_AIR = 2.68678e19
N_STP_2547 = 2.546899e19
N_STP_H2 = 2.65163e19
# Invalidate the memoised accessors so the next call re-reads the globals.
def _clear_cache():
    """Call ``cache_clear()`` on every lru_cache'd Rayleigh accessor."""
    memoised = (
        ray_species_names,
        ray_runtime_species_order,
        ray_master_wavelength,
        ray_sigma_table,
        ray_nm1_table,
        ray_nd_ref,
        ray_sigma_linear_table,
        ray_refractivity_coeff_table,
        ray_pick_arrays,
    )
    for accessor in memoised:
        accessor.cache_clear()
# Reset all module-level Rayleigh data and invalidate the cached accessors.
[docs]
def reset_registry():
    """Forget every loaded Rayleigh species.

    Empties the species tuples, drops all device tables back to ``None`` and
    clears the memoised accessor results so later calls observe the empty
    state.
    """
    global _RAY_ENTRIES, _RAY_SPECIES_NAMES
    global _RAY_SIGMA_CACHE, _RAY_WAVELENGTH_CACHE, _RAY_NM1_CACHE, _RAY_NDREF_CACHE
    global _RAY_SIGMA_LINEAR_CACHE, _RAY_REFRACTIVITY_COEFF_CACHE

    _RAY_ENTRIES = ()
    _RAY_SPECIES_NAMES = ()
    _RAY_SIGMA_CACHE = _RAY_WAVELENGTH_CACHE = None
    _RAY_NM1_CACHE = _RAY_NDREF_CACHE = None
    _RAY_SIGMA_LINEAR_CACHE = _RAY_REFRACTIVITY_COEFF_CACHE = None
    _clear_cache()
# Query whether any Rayleigh species have been loaded.
[docs]
def has_ray_data() -> bool:
    """Return True when at least one Rayleigh species is registered."""
    return len(_RAY_ENTRIES) > 0
# Functions to calculate index n and King factor (same as gCMCRT)
def _n_func(wn: np.ndarray, A: float, B: float, C: float) -> np.ndarray:
nm1 = A + B / (C - wn**2)
return nm1 / 1.0e8 + 1.0
def _n_func2(wl_um: np.ndarray, A: float, B: float) -> np.ndarray:
nm1 = A * (1.0 + B / (wl_um**2))
return nm1 + 1.0
def _king_from_Dpol_1(Dpol: float) -> float:
return (6.0 + 3.0 * Dpol) / (6.0 - 7.0 * Dpol)
def _king_from_Dpol_2(Dpol: float) -> float:
return (3.0 + 6.0 * Dpol) / (3.0 - 4.0 * Dpol)
# Atomic-hydrogen Rayleigh cross section: two-branch fit (as in gCMCRT).
def _sigma_H(freq: np.ndarray, wl_A: np.ndarray) -> np.ndarray:
    """Rayleigh cross section of atomic hydrogen.

    Parameters
    ----------
    freq : np.ndarray
        Frequency of each sample [Hz].
    wl_A : np.ndarray
        Wavelength array; only used to allocate an output of matching
        shape/dtype.

    Returns
    -------
    np.ndarray
        Cross section in the same units as ``SIGMA_T``. For angular
        frequencies up to ``0.6 * W_L`` an even power series in ``w / W_L``
        (coefficients ``CP``) is used; above that, an expansion in the
        fractional offset from ``0.75 * W_L``.
    """
    omega = 2.0 * PI * freq
    ratio = omega / W_L
    out = np.zeros_like(wl_A)

    low = ratio <= 0.6
    if low.any():
        r = ratio[low]
        series = np.zeros_like(r)
        for k, coeff in enumerate(CP):
            series += coeff * r**(2 * k)
        out[low] = series * r**4

    high = ~low
    if high.any():
        # Fractional distance from 0.75 * W_L (diverges exactly there).
        delta = (omega[high] - 0.75 * W_L) / (0.75 * W_L)
        bracket = (
            1.0
            - 1.792 * delta
            - 23.637 * delta**2
            - 83.1393 * delta**3
            - 244.1453 * delta**4
            - 699.473 * delta**5
        )
        out[high] = 0.0433056 / (delta**2) * bracket

    out *= SIGMA_T
    return out
# Per-species Rayleigh cross sections (formulae follow gCMCRT).
def _compute_species_sigma(name: str, wl_um: np.ndarray) -> np.ndarray:
    """Return only the Rayleigh cross section of species ``name`` on ``wl_um`` (microns)."""
    return _compute_species_sigma_nm1_ndref(name, wl_um)[0]
# Dipole polarisabilities [Angstrom^3] of species whose Rayleigh cross section
# is computed directly from the polarisability rather than from a
# refractive-index fit.
_POLARIZABILITY_A3 = {
    "HCl": 2.515,
    "HCN": 2.593,
    "H2S": 3.631,
    "OCS": 5.090,
    "SO2": 3.882,
    "C2H2": 3.487,
    "PH3": 4.237,
    "SO3": 4.297,
}


def _sigma_from_polarizability(alpha_A3: float, wn: np.ndarray) -> np.ndarray:
    """Rayleigh cross section (128/3) * pi**5 * alpha**2 * wn**4, alpha given in Angstrom^3."""
    a_vol = alpha_A3 / (1.0e8**3)  # Angstrom^3 -> cm^3
    return (128.0 / 3.0) * PI**5 * a_vol**2 * wn**4


def _compute_species_sigma_nm1_ndref(
    name: str,
    wl_um: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, float]:
    """Compute the Rayleigh cross section, refractivity and reference density of one species.

    Parameters
    ----------
    name : str
        Species label; surrounding whitespace is ignored. ``"e-"`` and
        ``"el"`` both select electron (Thomson) scattering.
    wl_um : array_like
        Wavelengths in microns.

    Returns
    -------
    xsec : np.ndarray
        Cross section per particle on ``wl_um``. Refractive-index species are
        floored at 1e-99.
    nm1 : np.ndarray
        Refractivity ``n - 1`` clipped at 0. It is all zeros for species without
        a refractive-index fit (e-, H and the polarisability-based species).
    nd_ref : float
        Reference number density used to normalise the refractive-index fit.
        It is 1.0 when ``nm1`` is identically zero.

    Raises
    ------
    NotImplementedError
        For ``"H2O"``, whose Rayleigh term is layer dependent.
    ValueError
        For any other unsupported species.
    """
    name = name.strip()
    wl_um = np.asarray(wl_um, dtype=float)
    wn = 1.0e4 / wl_um  # wavenumber [cm^-1]

    # --- species without a refractive-index fit: return directly ------------
    if name in ("e-", "el"):
        return np.full_like(wl_um, SIGMA_T), np.zeros_like(wl_um), 1.0
    if name == "H":
        wl_A = wl_um * 1.0e4
        freq = C_LIGHT / (wl_um * 1.0e-4)
        return _sigma_H(freq, wl_A), np.zeros_like(wl_um), 1.0
    if name in _POLARIZABILITY_A3:
        xsec = _sigma_from_polarizability(_POLARIZABILITY_A3[name], wn)
        return xsec, np.zeros_like(wl_um), 1.0
    if name == "H2O":
        raise NotImplementedError("Layer-dependent H2O Rayleigh must be handled in opacity calculations.")

    # --- species described by a refractive index n(wn) and a King factor ----
    if name == "H2":
        n = _n_func2(wl_um, 13.58e-5, 7.52e-3)
        king = 1.0
        nd_stp = N_STP_H2
    elif name == "He":
        n = _n_func(wn, 2283.0, 1.8102e13, 1.5342e10)
        king = 1.0
        nd_stp = N_STP_2547
    elif name == "CO":
        n = _n_func(wn, 22851.0, 0.456e14, 71427.0**2)
        king = 1.0
        nd_stp = N_STP_2547
    elif name == "CO2":
        n = 1.1427e3 * (
            5799.25 / (128908.9**2 - wn**2)
            + 120.05 / (89223.8**2 - wn**2)
            + 5.3334 / (75037.5**2 - wn**2)
            + 4.3244 / (67837.7**2 - wn**2)
            + 0.1218145e-6 / (2418.136**2 - wn**2)
        )
        n = n + 1.0
        king = 1.1364 + 25.3e-12 * wn**2
        nd_stp = N_STP_2547
    elif name == "CH4":
        n = (46662.0 + 4.02e-6 * wn**2) / 1.0e8 + 1.0
        king = 1.0
        nd_stp = N_STP_2547
    elif name == "O2":
        n = _n_func(wn, 20564.8, 2.480899e13, 4.09e9)
        # NOTE(review): Bates (1984) / Bodhaine et al. (1999) quote 1.096 for the
        # constant term; 1.09 is kept unchanged -- confirm against gCMCRT.
        king = 1.09 + 1.385e-11 * wn**2 + 1.448e-20 * wn**4
        nd_stp = N_STP_AIR
    elif name == "N2":
        # Two dispersion fits joined at 21360 cm^-1 (upper fit used strictly above).
        n = np.where(
            wn > 21360.0,
            _n_func(wn, 5677.465, 318.81874e12, 14.4e9),
            _n_func(wn, 6498.2, 307.4335e12, 14.4e9),
        )
        # Bates (1984): F(N2) = 1.034 + 3.17e-4 / lambda_um**2 = 1.034 + 3.17e-12 * wn**2.
        # The term must be quadratic in wn, exactly like the O2 King factor above.
        king = 1.034 + 3.17e-12 * wn**2
        nd_stp = N_STP_2547
    elif name == "NH3":
        n = _n_func2(wl_um, 37.0e-5, 12.0e-3)
        king = _king_from_Dpol_1(0.0922)
        nd_stp = N_STP_AIR
    elif name == "Ar":
        n = _n_func(wn, 6432.135, 286.06021e12, 14.4e9)
        king = 1.0
        nd_stp = N_STP_2547
    elif name == "N2O":
        n = (46890.0 + 4.12e-6 * wn**2) / 1.0e8 + 1.0
        king = _king_from_Dpol_2(0.0577 + 11.8e-12 * wn**2)  # wavenumber-dependent depolarisation
        nd_stp = N_STP_2547
    elif name == "SF6":
        n = (71517.0 + 4.996e-6 * wn**2) / 1.0e8 + 1.0
        king = 1.0
        nd_stp = N_STP_2547
    else:
        raise ValueError(f"Unsupported Rayleigh species '{name}' in ray registry.")

    # Lorentz-Lorenz cross section normalised by the fit's reference density.
    xsec = ((24.0 * PI**3 * wn**4) / (nd_stp**2)) * (((n**2 - 1.0) / (n**2 + 2.0))**2) * king
    nm1 = np.maximum(n - 1.0, 0.0)
    # Floor keeps a later log10 finite.
    return np.maximum(xsec, 1.0e-99), nm1, float(nd_stp)
# Compute Rayleigh cross sections for the configured species and populate the module-level caches.
[docs]
def load_ray_registry(cfg, obs, lam_master: Optional[np.ndarray] = None) -> None:
    """Compute Rayleigh data for every configured species and publish it as device caches.

    Parameters
    ----------
    cfg :
        Configuration object. ``cfg.opac.ray`` is the iterable of Rayleigh
        species specs; each spec's ``species`` attribute (or ``str(spec)``
        when absent) names the species.
    obs :
        Mapping providing ``obs["wl"]``. It is only read when ``lam_master`` is None.
    lam_master : array_like, optional
        Wavelength grid in microns. When given, it overrides ``obs["wl"]``.

    If no Rayleigh species are configured, the registry is reset instead.

    Raises
    ------
    ValueError, NotImplementedError
        Propagated from the per-species computation for unsupported species.
        All globals are assigned only after every table has been built, so
        the previously loaded registry stays intact on failure.
    """
    global _RAY_ENTRIES, _RAY_SIGMA_CACHE, _RAY_WAVELENGTH_CACHE, _RAY_NM1_CACHE, _RAY_NDREF_CACHE
    global _RAY_SPECIES_NAMES, _RAY_SIGMA_LINEAR_CACHE, _RAY_REFRACTIVITY_COEFF_CACHE

    config = getattr(cfg.opac, "ray", None)
    if not config:
        reset_registry()
        return

    source = obs["wl"] if lam_master is None else lam_master
    wavelengths = np.asarray(source, dtype=np.float64)

    # --- host-side preprocessing (NumPy, float64) ---------------------------
    entries: List[RayRegistryEntry] = []
    nm1_entries: List[np.ndarray] = []
    ndref_entries: List[float] = []
    for index, spec in enumerate(config):
        name = getattr(spec, "species", str(spec))
        print("[Ray] Computing Rayleigh xs for", name)
        xs, nm1, nd_ref = _compute_species_sigma_nm1_ndref(name, wavelengths)
        entries.append(
            RayRegistryEntry(
                name=name,
                idx=index,
                wavelengths=wavelengths.astype(np.float64),
                cross_sections=np.log10(xs).astype(np.float64),
            )
        )
        nm1_entries.append(nm1.astype(np.float64))
        ndref_entries.append(float(nd_ref))

    if not entries:
        reset_registry()
        return

    sigma_log10 = np.stack([entry.cross_sections for entry in entries], axis=0)  # (n_species, n_wl)
    nm1_stacked = np.stack(nm1_entries, axis=0)                                # (n_species, n_wl)
    ndref = np.asarray(ndref_entries, dtype=np.float64)                        # (n_species,)

    # --- single transfer to the device --------------------------------------
    # The log10 sigma and (n - 1) tables are stored as float32. The float64
    # derived tables are built from the float64 host stacks rather than from
    # the float32 device copies, so they do not inherit float32 rounding.
    print(f"[Ray] Transferring {len(entries)} species to device...")
    sigma_cache = jnp.asarray(sigma_log10, dtype=jnp.float32)
    wavelength_cache = jnp.asarray(entries[0].wavelengths, dtype=jnp.float64)
    nm1_cache = jnp.asarray(nm1_stacked, dtype=jnp.float32)
    ndref_cache = jnp.asarray(ndref, dtype=jnp.float64)
    sigma_linear_cache = jnp.asarray(10.0 ** sigma_log10, dtype=jnp.float64)
    refractivity_cache = jnp.asarray(nm1_stacked / ndref[:, None], dtype=jnp.float64)

    # Publish everything only once all device tables exist.
    _RAY_ENTRIES = tuple(entries)
    _RAY_SPECIES_NAMES = tuple(entry.name for entry in entries)
    _RAY_SIGMA_CACHE = sigma_cache
    _RAY_WAVELENGTH_CACHE = wavelength_cache
    _RAY_NM1_CACHE = nm1_cache
    _RAY_NDREF_CACHE = ndref_cache
    _RAY_SIGMA_LINEAR_CACHE = sigma_linear_cache
    _RAY_REFRACTIVITY_COEFF_CACHE = refractivity_cache

    print(f"[Ray] Cross section cache: {_RAY_SIGMA_CACHE.shape} (dtype: {_RAY_SIGMA_CACHE.dtype})")
    # Estimate memory usage (log10 sigma table only)
    sigma_mb = _RAY_SIGMA_CACHE.size * _RAY_SIGMA_CACHE.itemsize / 1024**2
    print(f"[Ray] Estimated device memory: {sigma_mb:.1f} MB")
    _clear_cache()
### -- lru cached helper functions below --- ###
[docs]
@lru_cache(maxsize=None)
def ray_species_names() -> Tuple[str, ...]:
    """Names of the loaded Rayleigh species in registry order (empty tuple if none)."""
    names = _RAY_SPECIES_NAMES
    return names
[docs]
@lru_cache(maxsize=None)
def ray_runtime_species_order() -> Tuple[str, ...]:
    """Species order used at runtime; currently the same tuple as ray_species_names()."""
    order = ray_species_names()
    return order
[docs]
@lru_cache(maxsize=None)
def ray_master_wavelength() -> jnp.ndarray:
    """Device copy of the wavelength grid shared by all species.

    Raises RuntimeError while the registry is empty.
    """
    grid = _RAY_WAVELENGTH_CACHE
    if grid is not None:
        return grid
    raise RuntimeError("Rayleigh registry empty; call build_opacities() first.")
[docs]
@lru_cache(maxsize=None)
def ray_sigma_table() -> jnp.ndarray:
    """Device table of log10 cross sections, one row per species.

    Raises RuntimeError while the table has not been built.
    """
    table = _RAY_SIGMA_CACHE
    if table is not None:
        return table
    raise RuntimeError("Rayleigh σ table not built; call build_opacities() first.")
[docs]
@lru_cache(maxsize=None)
def ray_nm1_table() -> jnp.ndarray:
    """Device table of refractivities (n - 1), one row per species.

    Raises RuntimeError while the table has not been built.
    """
    table = _RAY_NM1_CACHE
    if table is not None:
        return table
    raise RuntimeError("Rayleigh (n-1) table not built; call build_opacities() first.")
[docs]
@lru_cache(maxsize=None)
def ray_nd_ref() -> jnp.ndarray:
    """Device vector of per-species reference number densities.

    Raises RuntimeError while the table has not been built.
    """
    densities = _RAY_NDREF_CACHE
    if densities is not None:
        return densities
    raise RuntimeError("Rayleigh reference number density table not built; call build_opacities() first.")
[docs]
@lru_cache(maxsize=None)
def ray_sigma_linear_table() -> jnp.ndarray:
    """Device table of cross sections in linear (not log10) units.

    Raises RuntimeError while the table has not been built.
    """
    table = _RAY_SIGMA_LINEAR_CACHE
    if table is not None:
        return table
    raise RuntimeError("Rayleigh linear sigma table not built; call build_opacities() first.")
[docs]
@lru_cache(maxsize=None)
def ray_refractivity_coeff_table() -> jnp.ndarray:
    """Device table of (n - 1) / nd_ref coefficients, one row per species.

    Raises RuntimeError while the table has not been built.
    """
    table = _RAY_REFRACTIVITY_COEFF_CACHE
    if table is not None:
        return table
    raise RuntimeError("Rayleigh refractivity coefficient table not built; call build_opacities() first.")
[docs]
@lru_cache(maxsize=None)
def ray_pick_arrays():
    """Build per-species constant "picker" callables.

    Returns
    -------
    tuple[tuple, tuple]
        Two tuples with one callable per species. The first tuple's callables
        return the shared wavelength grid. The second tuple's callables return
        that species' row of the log10 sigma table. Each callable takes an
        optional, unused first argument.

    Raises RuntimeError while the registry is empty.
    """
    if _RAY_SIGMA_CACHE is None:
        raise RuntimeError("Rayleigh registry empty; call build_opacities() first.")
    count = int(_RAY_SIGMA_CACHE.shape[0])
    grid = ray_master_wavelength()
    table = ray_sigma_table()

    wavelength_pickers = []
    sigma_pickers = []
    for row in range(count):
        # Default arguments bind the current values (avoids late binding).
        wavelength_pickers.append(lambda _=None, wl=grid: wl)
        sigma_pickers.append(lambda _=None, xs=table[row]: xs)
    return tuple(wavelength_pickers), tuple(sigma_pickers)