Source code for exo_skryer.ck_mix_RORR

"""
ck_mix_RORR.py
==============
"""

import jax
import jax.numpy as jnp
from jax import lax


__all__ = ['mix_k_tables_rorr']

def _rom_mix_band(
    sigma_stack: jnp.ndarray,
    vmr_layer: jnp.ndarray,
    g_points: jnp.ndarray,
    base_weights: jnp.ndarray,
    rom_weights: jnp.ndarray,
) -> jnp.ndarray:
    """Mix multiple species at one (layer, wavelength) using RORR.

    Implements random overlap with resampling and reordering (RORR): sequentially add
    species by forming the ROM matrix, sorting by k, and interpolating back to
    the standard g-grid.

    Parameters
    ----------
    sigma_stack : `~jax.numpy.ndarray`
        Cross-sections for all species at this (layer, wavelength),
        shape `(n_species, n_g)`.
    vmr_layer : `~jax.numpy.ndarray`
        Volume mixing ratios for each species at this layer, shape `(n_species,)`.
    g_points : `~jax.numpy.ndarray`
        Standard g-points for interpolation, shape `(n_g,)`.
    base_weights : `~jax.numpy.ndarray`
        Quadrature weights for each g-point, shape `(n_g,)`.
    rom_weights : `~jax.numpy.ndarray`
        Pre-computed ROM weights (outer product of `base_weights`), shape `(n_g**2,)`.

    Returns
    -------
    `~jax.numpy.ndarray`
        Mixed cross-section at this (layer, wavelength), shape `(n_g,)`.
    """
    n_species = sigma_stack.shape[0]
    ng = sigma_stack.shape[-1]

    if n_species == 0:
        return jnp.zeros(ng, dtype=sigma_stack.dtype)

    # Initialize with first species
    vmr_tot = vmr_layer[0]
    cs_mix = sigma_stack[0] * vmr_tot

    if n_species == 1:
        return cs_mix

    # NOTE: rom_weights is pre-computed outside and passed in as parameter

    def body(carry, inputs):
        cs_mix_prev, vmr_tot_prev = carry
        sigma_spec, vmr_spec = inputs

        def skip_species(_):
            """Skip species with negligible cross-section."""
            vmr_tot = vmr_tot_prev + vmr_spec
            return (cs_mix_prev, vmr_tot), None

        def mix_species(_):
            """Perform RORR mixing for this species."""
            vmr_tot = vmr_tot_prev + vmr_spec

            # Create ROM matrix: k_rom_matrix[i,j] = (cs_mix[i] + vmr*sigma[j]) / vmr_tot
            k_rom_matrix = (cs_mix_prev[:, None] + vmr_spec * sigma_spec[None, :]) / vmr_tot

            # Flatten
            k_rom_flat = k_rom_matrix.ravel()

            # OPTIMIZATION: Sort pairs directly instead of argsort + fancy indexing
            k_rom_sorted, w_rom_sorted = lax.sort_key_val(k_rom_flat, rom_weights)
            k_rom_sorted = jnp.clip(k_rom_sorted, 1e-99, None)

            # Compute cumulative g with optimized normalization
            w_cumsum = jnp.cumsum(w_rom_sorted)
            g_rom = w_cumsum / w_cumsum[-1]

            # OPTIMIZATION: Cleaner log/exp operations using 10** for clarity
            log_k_interp = jnp.interp(g_points, g_rom, jnp.log10(k_rom_sorted))
            cs_mix_new = vmr_tot * (10.0 ** log_k_interp)

            return (cs_mix_new, vmr_tot), None

        # Skip if max cross-section is negligible (< 1e-50)
        return lax.cond(
            jnp.max(sigma_spec) < 1e-50,
            skip_species,
            mix_species,
            operand=None
        )

    # Scan over species 1 onwards
    (cs_mix_final, _), _ = lax.scan(
        body,
        (cs_mix, vmr_tot),
        (sigma_stack[1:], vmr_layer[1:])
    )

    return cs_mix_final


[docs] def mix_k_tables_rorr( sigma_values: jnp.ndarray, mixing_ratios: jnp.ndarray, g_points: jnp.ndarray, base_weights: jnp.ndarray, ) -> jnp.ndarray: """Mix correlated-k tables across species using RORR. Parameters ---------- sigma_values : `~jax.numpy.ndarray` Cross-sections for all species, shape `(n_species, n_layers, n_wavelength, n_g)`. mixing_ratios : `~jax.numpy.ndarray` Volume mixing ratios, shape `(n_species, n_layers)` or `(n_species,)`. If 1D, it is broadcast across layers. g_points : `~jax.numpy.ndarray` Standard g-points for interpolation, shape `(n_g,)`. base_weights : `~jax.numpy.ndarray` Quadrature weights, shape `(n_g,)`. Returns ------- `~jax.numpy.ndarray` Mixed cross-sections, shape `(n_layers, n_wavelength, n_g)`. """ n_species, n_layers, n_wl, n_g = sigma_values.shape dtype = sigma_values.dtype if n_species == 0: return jnp.zeros((n_layers, n_wl, n_g), dtype=dtype) if mixing_ratios.ndim == 1: mixing_ratios = jnp.broadcast_to(mixing_ratios[:, None], (n_species, n_layers)) # OPTIMIZATION: Pre-compute ROM weights ONCE (not n_layers * n_wl times!) rom_weights = jnp.outer(base_weights, base_weights).reshape(-1) wl_indices = jnp.arange(n_wl) def _mix_one_layer(layer_idx: jnp.ndarray) -> jnp.ndarray: vmr_layer = mixing_ratios[:, layer_idx] def _mix_one_wl(wl_idx): sigma_band = sigma_values[:, layer_idx, wl_idx, :] return _rom_mix_band(sigma_band, vmr_layer, g_points, base_weights, rom_weights) # Parallelize over wavelengths (independent). Species mixing inside _rom_mix_band # remains sequential (lax.scan) due to running mixture dependency. mixed_by_wl = jax.vmap(_mix_one_wl)(wl_indices) # (nwl, ng) return mixed_by_wl layer_indices = jnp.arange(n_layers) return jax.vmap(_mix_one_layer, in_axes=0)(layer_indices)