# Source code for exo_skryer.opacity_cia

"""
opacity_cia.py
==============
"""

from typing import Dict

import jax.numpy as jnp
from jax import lax

from . import build_opacities as XS

# Public API of this module: the CIA opacity kernel and its all-zero fallback.
__all__ = ["compute_cia_opacity", "zero_cia_opacity"]


def zero_cia_opacity(state: Dict[str, jnp.ndarray], params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """Return a zero CIA opacity array.

    This function is used as a fallback when no CIA species are enabled or
    when all CIA pairs are filtered out (e.g., H- opacity handled separately).

    Parameters
    ----------
    state : dict[str, jnp.ndarray]
        State dictionary containing:

        - `nlay` : int-like
            Number of atmospheric layers.
        - `nwl` : int-like
            Number of wavelength points.
    params : dict[str, jnp.ndarray]
        Parameter dictionary (unused; kept for API compatibility with other
        opacity calculation functions).

    Returns
    -------
    zeros : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Zero-valued CIA opacity array in cm² g⁻¹.
    """
    # Use the state entries directly as the shape (no jnp.size() call) for
    # JIT compatibility.
    shape = (state["nlay"], state["nwl"])
    return jnp.zeros(shape)
def compute_cia_opacity(
    state: Dict[str, jnp.ndarray],
    opac: Dict[str, jnp.ndarray],
    params: Dict[str, jnp.ndarray],
) -> jnp.ndarray:
    """Compute collision-induced absorption (CIA) mass opacity for all molecular pairs.

    This function calculates the total CIA opacity by summing contributions
    from all enabled molecular pairs (e.g., H2-He, H2-H2). For each pair, it:

    1. Interpolates pre-loaded cross-sections to layer temperatures
    2. Computes the VMR pair product (f_A × f_B)
    3. Applies the opacity formula: κ = f_A × f_B × (n_d)² / ρ × σ(λ, T)
    4. Sums over all pairs to get total CIA opacity

    Parameters
    ----------
    state : dict[str, `~jax.numpy.ndarray`]
        Atmospheric state dictionary containing:

        - `nlay` : int
            Number of atmospheric layers.
        - `nwl` : int
            Number of wavelength points.
        - `wl` : `~jax.numpy.ndarray`, shape (nwl,)
            Wavelength grid in microns (must match CIA table wavelengths).
        - `T_lay` : `~jax.numpy.ndarray`, shape (nlay,)
            Layer temperatures in Kelvin.
        - `nd_lay` : `~jax.numpy.ndarray`, shape (nlay,)
            Layer number density in molecule cm⁻³.
        - `rho_lay` : `~jax.numpy.ndarray`, shape (nlay,)
            Layer mass density in g cm⁻³.
        - `vmr_lay` : dict[str, `~jax.numpy.ndarray`]
            Volume mixing ratios for each species. Values can be scalars or
            arrays with shape (nlay,).
    opac : dict[str, `~jax.numpy.ndarray`]
        Pre-loaded CIA tables. The following keys are read:

        - `cia_master_wavelength`
            Wavelength grid of the CIA tables; only its shape is compared
            against ``state["wl"]``.
        - `cia_retained_sigma_cube` : shape (npairs, nT, nwl)
            log10 cross-sections for every retained pair.
        - `cia_retained_log10_temperature_grids`
            Per-pair log10 temperature nodes; row ``i`` has shape (nT,).
        - `cia_retained_temperature_grids`
            Per-pair temperature nodes; row ``i`` has shape (nT,). Only the
            first node of each row is used (as the table minimum).
        - `cia_pair_species_i`, `cia_pair_species_j`
            For each pair, the row indices of its two partner species in the
            VMR stack, which is ordered like
            ``XS.cia_runtime_species_order()``.
    params : dict[str, `~jax.numpy.ndarray`]
        Parameter dictionary (unused here except for being forwarded to
        `zero_cia_opacity`; kept for API compatibility with other opacity
        functions that may depend on retrieval parameters).

    Returns
    -------
    kappa_cia : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Total CIA mass opacity in cm² g⁻¹, summed over all molecular pairs.
        Returns zeros if no CIA pairs are enabled.

    Raises
    ------
    ValueError
        If the shape of the CIA wavelength grid differs from the shape of the
        forward model master grid (values are not compared).
    """
    # Use the state entry directly (no int() cast) for JIT compatibility.
    layer_count = state["nlay"]
    wavelengths = state["wl"]
    layer_temperatures = state["T_lay"]
    number_density = state["nd_lay"]  # (nlay,)
    density = state["rho_lay"]  # (nlay,)
    layer_vmr = state["vmr_lay"]

    master_wavelength = opac["cia_master_wavelength"]
    if master_wavelength.shape != wavelengths.shape:
        raise ValueError("CIA wavelength grid must match the forward-model master grid.")

    species_order = XS.cia_runtime_species_order()
    if not species_order:
        return zero_cia_opacity(state, params)

    sigma_cube = opac["cia_retained_sigma_cube"]  # (npairs, nT, nwl) in log10
    log_temperature_grids = opac["cia_retained_log10_temperature_grids"]
    temperature_grids = opac["cia_retained_temperature_grids"]
    pair_i = opac["cia_pair_species_i"]
    pair_j = opac["cia_pair_species_j"]

    log_t_layers = jnp.log10(layer_temperatures)
    # Pair-independent factor n_d² / ρ, computed once outside the loop.
    weights_nd2_over_rho = (number_density**2 / density)  # (nlay,)

    out = jnp.zeros((layer_count, wavelengths.shape[0]), dtype=jnp.float64)

    # One row per species (scalars broadcast to every layer), so the loop body
    # can select the two partners of a pair by integer index.
    vmr_stack = jnp.stack(
        [jnp.broadcast_to(layer_vmr[species], (layer_count,)) for species in species_order],
        axis=0,
    )

    def _accumulate_one(i: jnp.ndarray, acc: jnp.ndarray) -> jnp.ndarray:
        """Add the opacity contribution of pair ``i`` to the accumulator."""
        sigma_log_table = sigma_cube[i]  # (nT, nwl)
        log_temp_grid = log_temperature_grids[i]  # (nT,)
        temp_grid = temperature_grids[i]  # (nT,)

        # Locate the bracketing temperature nodes; clipping keeps the upper
        # neighbour (t_idx + 1) inside the table.
        t_idx = jnp.searchsorted(log_temp_grid, log_t_layers) - 1
        t_idx = jnp.clip(t_idx, 0, log_temp_grid.shape[0] - 2)

        # Linear weight in log10(T); clipping to [0, 1] holds the edge value
        # instead of extrapolating beyond the grid.
        t_weight = (log_t_layers - log_temp_grid[t_idx]) / (
            log_temp_grid[t_idx + 1] - log_temp_grid[t_idx]
        )
        t_weight = jnp.clip(t_weight, 0.0, 1.0)

        s_t0 = sigma_log_table[t_idx, :]
        s_t1 = sigma_log_table[t_idx + 1, :]
        s_interp = (1.0 - t_weight)[:, None] * s_t0 + t_weight[:, None] * s_t1

        # Layers colder than the first table node get a negligible log10
        # cross-section rather than the clamped edge value.
        min_temp = temp_grid[0]
        below_min = layer_temperatures < min_temp
        tiny = jnp.array(-199.0, dtype=s_interp.dtype)
        s_interp = jnp.where(below_min[:, None], tiny, s_interp)

        # The table stores log10 values; convert back to linear cross-sections.
        sigma_val = 10.0 ** s_interp.astype(jnp.float64)  # (nlay, nwl)

        pair_weight = vmr_stack[pair_i[i]] * vmr_stack[pair_j[i]]  # (nlay,)
        normalization = pair_weight * weights_nd2_over_rho  # (nlay,)
        return acc + normalization[:, None] * sigma_val

    # Sum over pairs with a traced loop; the accumulator is the (nlay, nwl) total.
    out = lax.fori_loop(
        0,
        sigma_cube.shape[0],
        _accumulate_one,
        out,
    )
    return out