# Source code for exo_skryer.opacity_ck

"""
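opacity_ck.py
=============

Correlated-k (CK) opacity routines: bilinear interpolation of pre-loaded k-tables
in (log10 P, log10 T), multi-species mixing (RORR or PRAS) into mass opacities,
and per-species (unmixed) opacities for transmission-multiplication overlap.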
"""

from typing import Dict

import jax
import jax.numpy as jnp
from jax import lax

from . import build_opacities as XS
from .data_constants import amu, bar
from .ck_mix_RORR import mix_k_tables_rorr
from .ck_mix_PRAS import mix_k_tables_pras

# Public API of this module; the underscore-prefixed helpers are internal.
__all__ = [
    "zero_ck_opacity",
    "compute_ck_opacity",
    "compute_ck_opacity_perspecies",
]


def _interpolate_sigma_log(
    layer_pressures_bar: jnp.ndarray,
    layer_temperatures: jnp.ndarray,
    *,
    opac: "Dict[str, jnp.ndarray] | None" = None,
) -> jnp.ndarray:
    """Bilinear interpolation of correlated-k cross-sections on (log P, log T) grids.

    Correlated-k opacity tables are interpolated to the given atmospheric layer
    conditions using bilinear interpolation in log₁₀(P)-log₁₀(T) space. The pressure
    grid is shared by all species, while each species has its own temperature grid.
    Layer values outside a grid are clamped to the grid edge (indices and weights are
    clipped, so there is no extrapolation). Cross-sections are returned still in
    log₁₀ space.

    Parameters
    ----------
    layer_pressures_bar : `~jax.numpy.ndarray`, shape (nlay,)
        Atmospheric layer pressures in bar.
    layer_temperatures : `~jax.numpy.ndarray`, shape (nlay,)
        Atmospheric layer temperatures in Kelvin.
    opac : dict[str, `~jax.numpy.ndarray`], optional, keyword-only
        Opacity tables with keys:

        - ``"ck_sigma_cube"`` : shape (nspecies, nT, nP, nwl, ng), log₁₀ cross-sections.
        - ``"ck_log10_pressure_grid"`` : shape (nP,), log₁₀ pressure grid in bar.
        - ``"ck_log10_temperature_grids"`` : shape (nspecies, nT), per-species
          log₁₀ temperature grids.

        Both grids are assumed ascending, as `jax.numpy.searchsorted` requires.
        When omitted (the default), the tables are read from the global opacity
        registry.

    Returns
    -------
    sigma_interp : `~jax.numpy.ndarray`, shape (nlay, nspecies, nwl, ng)
        Interpolated cross-sections in log₁₀ space with units of log₁₀(cm² molecule⁻¹).
        The layer axis comes FIRST because the outer `jax.vmap` maps over layers:

        - nlay: Number of atmospheric layers
        - nspecies: Number of absorbing species
        - nwl: Number of wavelength bins
        - ng: Number of g-points per wavelength bin
    """
    # NOTE: Without `opac`, this helper pulls tables from the global registry. If used
    # inside a large jitted forward model, that may increase compile-time constants;
    # pass `opac` explicitly (or use compute_ck_opacity()) in that case.
    if opac is None:
        sigma_cube = XS.ck_sigma_cube()
        log_p_grid = XS.ck_log10_pressure_grid()
        log_temperature_grids = XS.ck_log10_temperature_grids()
    else:
        sigma_cube = opac["ck_sigma_cube"]
        log_p_grid = opac["ck_log10_pressure_grid"]
        log_temperature_grids = opac["ck_log10_temperature_grids"]

    log_p_layers = jnp.log10(layer_pressures_bar)
    log_t_layers = jnp.log10(layer_temperatures)

    def _interp_one_layer(log_p: jnp.ndarray, log_t: jnp.ndarray) -> jnp.ndarray:
        # Pressure bracket indices and weights (shared across species)
        p_idx = jnp.searchsorted(log_p_grid, log_p) - 1
        p_idx = jnp.clip(p_idx, 0, log_p_grid.shape[0] - 2)
        p_weight = (log_p - log_p_grid[p_idx]) / (log_p_grid[p_idx + 1] - log_p_grid[p_idx])
        p_weight = jnp.clip(p_weight, 0.0, 1.0)

        def _interp_one_species(sigma_4d: jnp.ndarray, log_temp_grid: jnp.ndarray) -> jnp.ndarray:
            # Temperature bracket indices and weights (species-dependent)
            t_idx = jnp.searchsorted(log_temp_grid, log_t) - 1
            t_idx = jnp.clip(t_idx, 0, log_temp_grid.shape[0] - 2)
            t_weight = (log_t - log_temp_grid[t_idx]) / (log_temp_grid[t_idx + 1] - log_temp_grid[t_idx])
            t_weight = jnp.clip(t_weight, 0.0, 1.0)

            # Bilinear interpolation in (logT, logP). Indices are scalars here, so outputs are (nwl, ng).
            s_t0_p0 = sigma_4d[t_idx, p_idx, :, :]
            s_t0_p1 = sigma_4d[t_idx, p_idx + 1, :, :]
            s_t1_p0 = sigma_4d[t_idx + 1, p_idx, :, :]
            s_t1_p1 = sigma_4d[t_idx + 1, p_idx + 1, :, :]

            s_t0 = (1.0 - p_weight) * s_t0_p0 + p_weight * s_t0_p1
            s_t1 = (1.0 - p_weight) * s_t1_p0 + p_weight * s_t1_p1
            return (1.0 - t_weight) * s_t0 + t_weight * s_t1

        # Map over the species axis of the tables together with their temperature grids.
        return jax.vmap(_interp_one_species)(sigma_cube, log_temperature_grids)  # (nspec, nwl, ng)

    # NOTE: Mapping over layers puts the layer axis first, so this returns a large
    # (nlay, nspec, nwl, ng) array; callers that care about peak memory should perform
    # layer-wise interpolation + mixing instead of using this helper.
    return jax.vmap(_interp_one_layer)(log_p_layers, log_t_layers)

def _get_ck_quadrature(opac: Dict[str, jnp.ndarray]) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Return the g-points and quadrature weights for correlated-k integration.

    The values are read directly from the opacity cache; there is no fallback to the
    opacity registry.

    Parameters
    ----------
    opac : dict[str, `~jax.numpy.ndarray`]
        Opacity cache that must contain the ``"g_points"`` and ``"g_weights"`` entries.

    Returns
    -------
    g_points : `~jax.numpy.ndarray`
        Points at which the k-distribution is sampled (expected shape (ng,); typically
        cumulative-probability values in [0, 1]).
    weights : `~jax.numpy.ndarray`
        Quadrature weights for integration over g-space (expected shape (ng,);
        typically summing to 1).

    Raises
    ------
    KeyError
        If ``"g_points"`` or ``"g_weights"`` is missing from ``opac``.
    """
    return opac["g_points"], opac["g_weights"]


def zero_ck_opacity(
    state: Dict[str, jnp.ndarray],
    opac: Dict[str, jnp.ndarray],
    params: Dict[str, jnp.ndarray],
) -> jnp.ndarray:
    """Return a zero correlated-k opacity array.

    This function is used as a fallback when correlated-k opacities are disabled in
    the configuration. It keeps the same signature as `compute_ck_opacity()` so the
    forward model can switch between CK enabled/disabled without other changes.

    Parameters
    ----------
    state : dict[str, `~jax.numpy.ndarray`]
        State dictionary containing:

        - `p_lay` : `~jax.numpy.ndarray`, shape (nlay,)
            Layer pressures (used only to determine array size).
        - `wl` : `~jax.numpy.ndarray`, shape (nwl,)
            Wavelength grid (used only to determine array size).
    opac : dict[str, `~jax.numpy.ndarray`]
        Opacity cache. Only ``opac["g_weights"]`` is read; the length of its last
        axis sets the number of g-points ``ng``.
    params : dict[str, `~jax.numpy.ndarray`]
        Parameter dictionary (unused; kept for API compatibility).

    Returns
    -------
    zeros : `~jax.numpy.ndarray`, shape (nlay, nwl, ng)
        Zero-valued correlated-k opacity array in cm² g⁻¹.
    """
    layer_count = state["p_lay"].shape[0]
    wavelength_count = state["wl"].shape[0]
    # The number of g-points comes from the quadrature weights in the opac cache.
    n_g = opac["g_weights"].shape[-1]
    return jnp.zeros((layer_count, wavelength_count, n_g))


def compute_ck_opacity(
    state: Dict[str, jnp.ndarray],
    opac: Dict[str, jnp.ndarray],
    params: Dict[str, jnp.ndarray],
) -> jnp.ndarray:
    """Compute correlated-k opacity with multi-species mixing.

    This function calculates the total atmospheric opacity using the correlated-k
    approximation. It works one layer at a time (inside `jax.lax.fori_loop`) so the
    full (nspecies, nlay, nwl, ng) cross-section array is never materialized:

    1. Interpolates the k-tables to the layer's (P, T) in log₁₀ space
    2. Mixes k-distributions from all species using the RORR or PRAS scheme
    3. Converts from cross-section (cm² molecule⁻¹) to mass opacity (cm² g⁻¹)

    Parameters
    ----------
    state : dict[str, `~jax.numpy.ndarray`]
        Atmospheric state dictionary containing:

        - `p_lay` : `~jax.numpy.ndarray`, shape (nlay,)
            Layer pressures in dyne cm⁻² (divided by ``bar`` before interpolation).
        - `T_lay` : `~jax.numpy.ndarray`, shape (nlay,)
            Layer temperatures in Kelvin.
        - `mu_lay` : `~jax.numpy.ndarray`, shape (nlay,)
            Mean molecular weight per layer in amu.
        - `vmr_lay` : dict[str, `~jax.numpy.ndarray`]
            Volume mixing ratios for each species. Keys must match the runtime CK
            species names exactly. Values can be scalars or arrays with shape (nlay,).
        - `ck_mix` : str or int, optional
            Mixing method selector. ``"PRAS"`` (case-insensitive) or code 2 selects
            PRAS; anything else, including the default code 1 and ``"RORR"``,
            selects RORR.

        Other entries (e.g. `wl`) are not read.
    opac : dict[str, `~jax.numpy.ndarray`]
        Opacity cache containing:

        - `ck_sigma_cube` : shape (nspecies, nT, nP, nwl, ng)
            log₁₀ cross-sections.
        - `ck_log10_pressure_grid` : shape (nP,)
            log₁₀ pressure grid in bar.
        - `ck_log10_temperature_grids` : shape (nspecies, nT)
            Per-species log₁₀ temperature grids.
        - `g_points`, `g_weights`
            Quadrature points and weights passed to the mixing routine.
    params : dict[str, `~jax.numpy.ndarray`]
        Parameter dictionary (unused; VMRs come from ``state['vmr_lay']``).
        Kept for API compatibility with other opacity functions.

    Returns
    -------
    kappa_ck : `~jax.numpy.ndarray`, shape (nlay, nwl, ng)
        Total atmospheric mass opacity in cm² g⁻¹ at each layer, wavelength bin,
        and g-point.
    """
    layer_pressures = state["p_lay"]
    layer_temperatures = state["T_lay"]
    layer_mu = state["mu_lay"]
    layer_count = layer_pressures.shape[0]

    # Species order follows the runtime CK registry; VMR keys must match it exactly.
    species_names = XS.ck_runtime_species_order()
    layer_vmr = state["vmr_lay"]
    # Scalars are broadcast to (nlay,) so every species row has the same length.
    mixing_ratios = jnp.stack(
        [jnp.broadcast_to(layer_vmr[name], (layer_count,)) for name in species_names],
        axis=0,
    )  # (nspec, nlay)

    g_points, g_weights = _get_ck_quadrature(opac)

    # Resolve the mixing method once, before the layer loop (default: RORR).
    # Backwards-compatible: accept either a string ("RORR"/"PRAS") or an int code.
    ck_mix_raw = state.get("ck_mix", 1)
    if isinstance(ck_mix_raw, str):
        use_pras = ck_mix_raw.upper() == "PRAS"
    else:
        # Avoid int() conversion for JIT compatibility. The comparison is done here,
        # not inside the fori_loop body: there, comparing a JAX array would be staged
        # into the loop trace and the resulting tracer cannot drive a Python `if`.
        use_pras = ck_mix_raw == 2

    # Layer-wise interpolation + mixing to avoid materializing (n_species, n_layers, n_wl, n_g).
    sigma_cube = opac["ck_sigma_cube"]
    log_p_grid = opac["ck_log10_pressure_grid"]
    log_temperature_grids = opac["ck_log10_temperature_grids"]

    # The pressure grid is in log10(bar); p_lay is converted from dyne cm^-2 first.
    log_p_layers = jnp.log10(layer_pressures / bar)
    log_t_layers = jnp.log10(layer_temperatures)

    n_wl = sigma_cube.shape[-2]
    n_g = sigma_cube.shape[-1]

    def _interp_sigma_log_layer(log_p: jnp.ndarray, log_t: jnp.ndarray) -> jnp.ndarray:
        """Bilinearly interpolate log10 cross-sections of all species at one (log P, log T)."""
        # Pressure bracket (shared across species); out-of-grid values are clamped.
        p_idx = jnp.searchsorted(log_p_grid, log_p) - 1
        p_idx = jnp.clip(p_idx, 0, log_p_grid.shape[0] - 2)
        p_weight = (log_p - log_p_grid[p_idx]) / (log_p_grid[p_idx + 1] - log_p_grid[p_idx])
        p_weight = jnp.clip(p_weight, 0.0, 1.0)

        def _interp_one_species(sigma_4d: jnp.ndarray, log_temp_grid: jnp.ndarray) -> jnp.ndarray:
            # Temperature bracket (species-dependent); out-of-grid values are clamped.
            t_idx = jnp.searchsorted(log_temp_grid, log_t) - 1
            t_idx = jnp.clip(t_idx, 0, log_temp_grid.shape[0] - 2)
            t_weight = (log_t - log_temp_grid[t_idx]) / (log_temp_grid[t_idx + 1] - log_temp_grid[t_idx])
            t_weight = jnp.clip(t_weight, 0.0, 1.0)

            s_t0_p0 = sigma_4d[t_idx, p_idx, :, :]
            s_t0_p1 = sigma_4d[t_idx, p_idx + 1, :, :]
            s_t1_p0 = sigma_4d[t_idx + 1, p_idx, :, :]
            s_t1_p1 = sigma_4d[t_idx + 1, p_idx + 1, :, :]

            s_t0 = (1.0 - p_weight) * s_t0_p0 + p_weight * s_t0_p1
            s_t1 = (1.0 - p_weight) * s_t1_p0 + p_weight * s_t1_p1
            return (1.0 - t_weight) * s_t0 + t_weight * s_t1

        return jax.vmap(_interp_one_species)(sigma_cube, log_temperature_grids)

    def _mix_one_layer(layer_idx: jnp.ndarray, out: jnp.ndarray) -> jnp.ndarray:
        sigma_log_layer = _interp_sigma_log_layer(
            log_p_layers[layer_idx], log_t_layers[layer_idx]
        )  # (nspec, nwl, ng)
        vmr_layer = mixing_ratios[:, layer_idx]  # (nspec,)

        # The mixers are called with a singleton layer axis ([:, None]); [0] then
        # selects that single layer from the result.
        if use_pras:
            # PRAS is given the log10 cross-sections directly.
            mixed = mix_k_tables_pras(
                sigma_log_layer[:, None, :, :],
                vmr_layer[:, None],
                g_points,
                g_weights,
            )[0]
        else:
            # RORR is given linear cross-sections, exponentiated in float64.
            mixed = mix_k_tables_rorr(
                10.0 ** sigma_log_layer[:, None, :, :].astype(jnp.float64),
                vmr_layer[:, None],
                g_points,
                g_weights,
            )[0]
        return out.at[layer_idx].set(mixed)

    mixed_sigma = lax.fori_loop(
        0,
        layer_count,
        _mix_one_layer,
        jnp.zeros((layer_count, n_wl, n_g), dtype=jnp.float64),
    )

    # Convert to mass opacity (cm^2 / g)
    return mixed_sigma / (layer_mu[:, None, None] * amu)


def compute_ck_opacity_perspecies(
    state: Dict[str, jnp.ndarray],
    opac: Dict[str, jnp.ndarray],
    params: Dict[str, jnp.ndarray],
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Compute per-species correlated-k opacities WITHOUT mixing.

    This function is used with the transmission multiplication random overlap method
    (ck_mix: trans), where species mixing happens during the RT calculation rather
    than at the opacity computation stage.

    Parameters
    ----------
    state : dict[str, `~jax.numpy.ndarray`]
        Atmospheric state dictionary containing:

        - `p_lay` : `~jax.numpy.ndarray`, shape (nlay,)
            Layer pressures in dyne cm⁻² (divided by ``bar`` before interpolation).
        - `T_lay` : `~jax.numpy.ndarray`, shape (nlay,)
            Layer temperatures in Kelvin.
        - `mu_lay` : `~jax.numpy.ndarray`, shape (nlay,)
            Mean molecular weight per layer in amu.
        - `vmr_lay` : dict[str, `~jax.numpy.ndarray`]
            Volume mixing ratios for each species; keys must match the runtime CK
            species names. Values can be scalars or arrays with shape (nlay,).

        Other entries (e.g. `wl`) are not read.
    opac : dict[str, `~jax.numpy.ndarray`]
        Opacity cache containing ``ck_sigma_cube`` (shape (nspecies, nT, nP, nwl, ng),
        log₁₀ cross-sections), ``ck_log10_pressure_grid`` (shape (nP,), log₁₀ bar)
        and ``ck_log10_temperature_grids`` (shape (nspecies, nT)).
    params : dict[str, `~jax.numpy.ndarray`]
        Parameter dictionary (unused; kept for API compatibility).

    Returns
    -------
    sigma_perspecies : `~jax.numpy.ndarray`, shape (n_species, nlay, nwl, ng)
        Per-species mass opacities in cm² g⁻¹. Note: these are NOT yet weighted by
        VMR - that happens in the RT calculation.
    vmr_perspecies : `~jax.numpy.ndarray`, shape (n_species, nlay)
        Volume mixing ratios for each species at each layer.
    """
    layer_pressures = state["p_lay"]
    layer_temperatures = state["T_lay"]
    layer_mu = state["mu_lay"]
    layer_count = layer_pressures.shape[0]

    # Species order follows the runtime CK registry; VMR keys must match it exactly.
    species_names = XS.ck_runtime_species_order()
    layer_vmr = state["vmr_lay"]
    mixing_ratios = jnp.stack(
        [jnp.broadcast_to(layer_vmr[name], (layer_count,)) for name in species_names],
        axis=0,
    )  # (n_species, nlay)

    # k-table data
    sigma_cube = opac["ck_sigma_cube"]
    log_p_grid = opac["ck_log10_pressure_grid"]
    log_temperature_grids = opac["ck_log10_temperature_grids"]

    # The pressure grid is in log10(bar); p_lay is converted from dyne cm^-2 first.
    log_p_layers = jnp.log10(layer_pressures / bar)
    log_t_layers = jnp.log10(layer_temperatures)

    def _interp_sigma_log_layer(log_p: jnp.ndarray, log_t: jnp.ndarray) -> jnp.ndarray:
        """Interpolate cross-sections for all species at one layer."""
        p_idx = jnp.searchsorted(log_p_grid, log_p) - 1
        p_idx = jnp.clip(p_idx, 0, log_p_grid.shape[0] - 2)
        p_weight = (log_p - log_p_grid[p_idx]) / (log_p_grid[p_idx + 1] - log_p_grid[p_idx])
        p_weight = jnp.clip(p_weight, 0.0, 1.0)

        def _interp_one_species(sigma_4d: jnp.ndarray, log_temp_grid: jnp.ndarray) -> jnp.ndarray:
            t_idx = jnp.searchsorted(log_temp_grid, log_t) - 1
            t_idx = jnp.clip(t_idx, 0, log_temp_grid.shape[0] - 2)
            t_weight = (log_t - log_temp_grid[t_idx]) / (log_temp_grid[t_idx + 1] - log_temp_grid[t_idx])
            t_weight = jnp.clip(t_weight, 0.0, 1.0)

            s_t0_p0 = sigma_4d[t_idx, p_idx, :, :]
            s_t0_p1 = sigma_4d[t_idx, p_idx + 1, :, :]
            s_t1_p0 = sigma_4d[t_idx + 1, p_idx, :, :]
            s_t1_p1 = sigma_4d[t_idx + 1, p_idx + 1, :, :]

            s_t0 = (1.0 - p_weight) * s_t0_p0 + p_weight * s_t0_p1
            s_t1 = (1.0 - p_weight) * s_t1_p0 + p_weight * s_t1_p1
            return (1.0 - t_weight) * s_t0 + t_weight * s_t1

        return jax.vmap(_interp_one_species)(sigma_cube, log_temperature_grids)  # (nspec, nwl, ng)

    # out_axes=1 inserts the mapped layer axis second, producing (nspec, nlay, nwl, ng)
    # directly rather than (nlay, nspec, nwl, ng) followed by a transpose.
    sigma_log_all = jax.vmap(_interp_sigma_log_layer, out_axes=1)(log_p_layers, log_t_layers)

    # Convert from log10 to linear space
    sigma_linear = 10.0 ** sigma_log_all.astype(jnp.float64)

    # Convert to mass opacity (cm² / g): sigma_linear is a cross-section
    # (cm² molecule⁻¹), so divide by the mean molecular mass (mu * amu).
    sigma_perspecies = sigma_linear / (layer_mu[None, :, None, None] * amu)
    return sigma_perspecies, mixing_ratios