Source code for exo_skryer.ck_mix_PRAS

"""
ck_mix_PRAS.py
=================
"""

import jax
import jax.numpy as jnp

from .aux_functions import latin_hypercube, pchip_1d

# Public API of this module: only the table-level mixer is exported; the
# underscore-prefixed per-band helper is left out of `from ... import *`.
__all__ = [
    "mix_k_tables_pras",
]


def _pras_mix_band(
    sigma_stack_log: jnp.ndarray,
    vmr_layer: jnp.ndarray,
    g_points: jnp.ndarray,
    base_weights: jnp.ndarray,
    key: jax.Array,
) -> jnp.ndarray:
    """Mix the per-species k-distributions of a single band by sampling.

    For every species a set of sample positions is drawn, the species'
    log10 cross-section curve is evaluated at those positions, the
    VMR-weighted cross-sections are summed across species, and the sorted
    totals are resampled back onto ``g_points``.

    Parameters
    ----------
    sigma_stack_log
        Log10 cross-sections with species along axis 0; each row is
        interpolated against ``g_points`` (presumably shape
        ``(n_species, n_g)`` -- confirm against the caller).
    vmr_layer
        One weight per species, broadcast against the samples of that
        species.
    g_points
        Abscissa used both as the interpolation nodes of the per-species
        curves and as the output grid.
    base_weights
        Part of the signature but not used by this routine.
    key
        PRNG key handed to ``latin_hypercube``.

    Returns
    -------
    jnp.ndarray
        Mixed (linear, not log) cross-sections evaluated at ``g_points``.
    """
    # Part of the signature, but this sampling scheme never reads it.
    del base_weights

    species_count = sigma_stack_log.shape[0]
    n_g = g_points.shape[0]
    # The number of random draws is the square of the g-grid length.
    sample_count = n_g * n_g

    # NOTE(review): latin_hypercube presumably returns samples in [0, 1) with
    # shape (sample_count, species_count) as its first output -- confirm.
    draws, _ = latin_hypercube(
        key,
        sample_count,
        species_count,
        scramble=True,
        dtype=jnp.float64,
    )
    draws_per_species = draws.T  # (n_species, n_samples)

    # Floor the log10 values so that 10 ** value stays bounded below.
    floored_log = jnp.maximum(sigma_stack_log, -99.0)

    # Evaluate each species' curve at that species' own sample positions.
    sample_species_curve = jax.vmap(
        lambda positions, log_curve: pchip_1d(positions, g_points, log_curve)
    )
    sampled_log = sample_species_curve(draws_per_species, floored_log)  # (n_species, n_samples)

    # Back to linear space, weight by mixing ratio, and add up the species.
    weighted = vmr_layer[:, None] * (10.0 ** sampled_log)
    total = jnp.maximum(jnp.sum(weighted, axis=0), 1e-99)

    # Sorting the totals gives the mixed distribution, tabulated at the
    # midpoints of sample_count equal-width bins on [0, 1].
    ordered_total = jnp.sort(total)
    bin_midpoints = (jnp.arange(sample_count, dtype=jnp.float64) + 0.5) / jnp.asarray(
        sample_count, dtype=jnp.float64
    )
    return jnp.interp(g_points, bin_midpoints, ordered_total)


def mix_k_tables_pras(
    sigma_values_log: jnp.ndarray,
    mixing_ratios: jnp.ndarray,
    g_points: jnp.ndarray,
    base_weights: jnp.ndarray,
) -> jnp.ndarray:
    """Mix per-species k-tables for every layer and wavelength bin.

    Each (layer, wavelength) band is mixed independently by
    ``_pras_mix_band``, using a PRNG key derived from a fixed seed folded
    with the layer index and then the wavelength index.

    Parameters
    ----------
    sigma_values_log
        Array of shape ``(n_species, n_layers, n_wl, n_g)`` holding the
        per-species values that are passed band by band to
        ``_pras_mix_band`` (log10 cross-sections, judging by the name and
        the helper -- confirm units against the caller).
    mixing_ratios
        Either a 1-D array of length ``n_species`` (the same ratio is then
        used for every layer) or a 2-D array indexed as
        ``[species, layer]``.
    g_points
        g-grid forwarded to ``_pras_mix_band``.
    base_weights
        Forwarded unchanged to ``_pras_mix_band``.

    Returns
    -------
    jnp.ndarray
        Mixed values with layers on axis 0 and wavelength bins on axis 1;
        the last axis is whatever ``_pras_mix_band`` returns per band.
        When there are no species, zeros of shape
        ``(n_layers, n_wl, n_g)`` in the dtype of ``sigma_values_log``.
    """
    n_species, n_layers, n_wl, n_g = sigma_values_log.shape
    dtype = sigma_values_log.dtype

    # Nothing to mix: return an all-zero table without touching the sampler.
    if n_species == 0:
        return jnp.zeros((n_layers, n_wl, n_g), dtype=dtype)

    # A per-species vector of ratios applies identically to every layer.
    if mixing_ratios.ndim == 1:
        mixing_ratios = jnp.broadcast_to(mixing_ratios[:, None], (n_species, n_layers))

    # Fixed seed: the per-band keys depend only on (layer, wavelength) indices.
    base_key = jax.random.PRNGKey(0)
    wl_indices = jnp.arange(n_wl)

    def _mix_one_layer(layer_index: jnp.ndarray) -> jnp.ndarray:
        vmr_layer = mixing_ratios[:, layer_index]
        key_layer = jax.random.fold_in(base_key, layer_index)

        def _mix_one_wl(wl_index):
            sigma_band = sigma_values_log[:, layer_index, wl_index, :]
            key = jax.random.fold_in(key_layer, wl_index)
            return _pras_mix_band(sigma_band, vmr_layer, g_points, base_weights, key)

        # Parallelize over wavelengths: each bin is independent with unique PRNG key
        mixed_by_wl = jax.vmap(_mix_one_wl)(wl_indices)  # (nwl, ng)
        return mixed_by_wl

    layer_indices = jnp.arange(n_layers)
    return jax.vmap(_mix_one_layer, in_axes=0)(layer_indices)