Source code for exo_skryer.RT_trans_1D_ck_trans

"""
RT_trans_1D_ck_trans.py
=======================

Transit transmission spectrum calculation using the transmission multiplication
random overlap method for correlated-k species mixing.

This module differs from RT_trans_1D_ck.py in that species are combined by
multiplying their mean transmissions under the random-overlap assumption:

    T_total = exp(-tau_cont) * Π_s [ Σ_g w_g exp(-tau_s(g)) ]

This avoids ROM sorting / k-distribution mixing entirely and is intended as a
fast transmission-only approximation.
"""

from __future__ import annotations

from typing import Dict, Mapping, Tuple

import jax.numpy as jnp
from jax import lax

from .refraction import maybe_refraction_cutoff_mask
from .RT_trans_1D_os import _get_base_transit_radius

__all__ = ["compute_transit_depth_1d_ck_trans"]


def _get_ck_quadrature(opac):
    """Extract g-points and weights from opac cache."""
    g_points_all = opac["g_points"]
    g_weights = opac["g_weights"]

    if g_points_all.ndim == 1:
        g_points = g_points_all
    else:
        g_points = g_points_all[0]

    if g_weights.ndim > 1:
        g_weights = g_weights[0]

    return g_points, g_weights


def _build_transit_geometry(state: Dict[str, jnp.ndarray]) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Precompute geometry terms for transit depth calculation."""
    R0 = state["R0"]
    z_lev = state["z_lev"]
    z_lay = state["z_lay"]

    r_mid = R0 + z_lay
    r_low = R0 + z_lev[:-1]
    r_up = R0 + z_lev[1:]
    dr = r_up - r_low

    r_mid_2d = r_mid[:, None]
    r_up_2d = r_up[None, :]
    r_low_2d = r_low[None, :]
    dr_2d = dr[None, :]

    sqrt_up = jnp.sqrt(jnp.maximum(r_up_2d**2 - r_mid_2d**2, 0.0))
    sqrt_low = jnp.sqrt(jnp.maximum(r_low_2d**2 - r_mid_2d**2, 0.0))

    P_case1 = jnp.zeros_like(sqrt_up)
    P_case2 = 2.0 / dr_2d * sqrt_up
    P_case3 = 2.0 / dr_2d * (sqrt_up - sqrt_low)

    cond1 = r_up_2d <= r_mid_2d
    cond2 = (r_low_2d <= r_mid_2d) & (r_mid_2d < r_up_2d)

    P1D = jnp.where(cond1, P_case1, jnp.where(cond2, P_case2, P_case3))
    area_weight = 2.0 * r_mid * dr

    return P1D, area_weight


def _sum_opacity_components_2d(
    state: Dict[str, jnp.ndarray],
    opacity_components: Mapping[str, jnp.ndarray],
) -> jnp.ndarray:
    """Sum 2D opacity components (rayleigh, cia, special, cloud).

    Returns shape (nlay, nwl).
    """
    nlay = state["rho_lay"].shape[0]
    nwl = state["wl"].shape[0] if "wl" in state else int(state["nwl"])
    zeros_2d = jnp.zeros((nlay, nwl), dtype=state["rho_lay"].dtype)

    component_keys = ("rayleigh", "cia", "special", "cloud")
    components = jnp.stack([opacity_components.get(k, zeros_2d) for k in component_keys], axis=0)
    return jnp.sum(components, axis=0)


def _integrate_g_points_trans(
    sigma_perspecies: jnp.ndarray,  # (n_species, nlay, nwl, ng)
    vmr_perspecies: jnp.ndarray,    # (n_species, nlay)
    other_opacity_2d: jnp.ndarray,  # (nlay, nwl) - rayleigh, CIA, etc.
    g_points: jnp.ndarray,          # (ng,)
    g_weights: jnp.ndarray,         # (ng,)
    state: Dict[str, jnp.ndarray],
    geometry: tuple[jnp.ndarray, jnp.ndarray],
    refraction_mask: jnp.ndarray | None,
    want_contrib: bool,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Integrate transit depth using RO transmission-function multiplication.

    This computes the slant-path transmission per impact parameter and wavelength as:
        T_total = exp(-tau_cont) * Π_s <exp(-tau_s)>_g
    where <...>_g is the g-quadrature average at fixed (impact parameter, wavelength).
    """
    R_base = _get_base_transit_radius(state)
    R_s = state["R_s"]
    rho = state["rho_lay"]
    dz = state["dz"]
    P1D, area_weight = geometry

    nlay = state["rho_lay"].shape[0]
    nwl = other_opacity_2d.shape[1]

    del g_points

    if want_contrib:
        raise NotImplementedError(
            "Contribution functions are not implemented for ck_mix=TRANS "
            "(RO transmission-product approximation)."
        )

    scale = rho * dz  # (nlay,)

    # Continuum-like opacity (2D) -> slant optical depth (nlay, nwl)
    k_cont = jnp.maximum(other_opacity_2d, 0.0)
    dtau_v_cont = k_cont * scale[:, None]  # (nlay, nwl)
    tau_path_cont = jnp.einsum("ij,jw->iw", P1D, dtau_v_cont)  # (nlay, nwl)

    nspec = sigma_perspecies.shape[0]
    ng = sigma_perspecies.shape[-1]
    w = g_weights[:ng].astype(sigma_perspecies.dtype)

    # Multiply mean transmissions over species at each (impact parameter, wavelength)
    T_prod0 = jnp.ones((nlay, nwl), dtype=sigma_perspecies.dtype)

    def _body(spec_idx: int, T_prod: jnp.ndarray) -> jnp.ndarray:
        kappa_s = sigma_perspecies[spec_idx]  # (nlay, nwl, ng)
        vmr_s = vmr_perspecies[spec_idx]      # (nlay,)

        dtau_v_s = kappa_s * vmr_s[:, None, None] * scale[:, None, None]  # (nlay, nwl, ng)
        tau_path_s = jnp.einsum("ij,jwg->iwg", P1D, dtau_v_s)             # (nlay, nwl, ng)
        tau_path_s = jnp.clip(tau_path_s, 0.0, 300.0)
        T_s_g = jnp.exp(-tau_path_s)                                      # (nlay, nwl, ng)
        T_s = jnp.sum(T_s_g * w[None, None, :], axis=-1)                  # (nlay, nwl)
        return T_prod * jnp.clip(T_s, 1e-99, 1.0)

    # Important: when nspec==0 (e.g. continuum-only diagnostics), tracing a fori_loop
    # would still stage `_body` and attempt `sigma_perspecies[0]`, which is invalid
    # for a zero-length leading dimension.
    if nspec == 0:
        T_prod = T_prod0
    else:
        T_prod = lax.fori_loop(0, nspec, _body, T_prod0)  # (nlay, nwl)

    T_total = jnp.exp(-jnp.clip(tau_path_cont, 0.0, 300.0)) * T_prod  # (nlay, nwl)
    if refraction_mask is not None:
        T_total = jnp.where(refraction_mask, 0.0, T_total)
    one_minus_trans = 1.0 - jnp.clip(T_total, 0.0, 1.0)               # (nlay, nwl)

    dR2 = jnp.sum(area_weight[:, None] * one_minus_trans, axis=0)     # (nwl,)
    D_net = (R_base**2 + dR2) / (R_s**2)

    layer_dR2 = jnp.zeros((nlay, nwl), dtype=D_net.dtype)
    return D_net, dR2, layer_dR2


[docs] def compute_transit_depth_1d_ck_trans( state: Dict[str, jnp.ndarray], params: Dict[str, jnp.ndarray], opacity_components: Mapping[str, jnp.ndarray], opac: Dict[str, jnp.ndarray], ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Compute 1D transit depth using transmission multiplication random overlap. This function expects per-species opacities in opacity_components: - 'line_perspecies': (n_species, nlay, nwl, ng) per-species mass opacities - 'vmr_perspecies': (n_species, nlay) volume mixing ratios Parameters ---------- state : dict Atmospheric state dictionary. params : dict Parameter dictionary (may contain 'f_cloud'). opacity_components : dict Opacity components including 'line_perspecies', 'vmr_perspecies', and optionally 'rayleigh', 'cia', 'special', 'cloud'. Returns ------- D_net : array, shape (nwl,) Transit depth spectrum. contrib_func : array, shape (nlay, nwl) Normalized contribution function. """ contri_func = state.get("contri_func", False) refraction_mask = maybe_refraction_cutoff_mask(state, params, opac) geometry = _build_transit_geometry(state) g_points, g_weights = _get_ck_quadrature(opac) # Get per-species opacities sigma_perspecies = opacity_components.get("line_perspecies") vmr_perspecies = opacity_components.get("vmr_perspecies") if sigma_perspecies is None or vmr_perspecies is None: raise ValueError( "compute_transit_depth_1d_ck_trans requires 'line_perspecies' and " "'vmr_perspecies' in opacity_components. Use ck_mix: trans in config." ) # Sum 2D opacity components other_opacity_2d = _sum_opacity_components_2d(state, opacity_components) # Handle cloud fraction if present if "f_cloud" in params and "cloud" in opacity_components: f_cloud = jnp.clip(params["f_cloud"], 0.0, 1.0) cloud_component = opacity_components["cloud"] # With clouds other_with_cloud = other_opacity_2d # Without clouds other_no_cloud = jnp.maximum(other_opacity_2d - cloud_component, 0.0) if contri_func: D_cloud, dR2_cloud, layer_dR2_cloud = _integrate_g_points_trans( sigma_perspecies, vmr_perspecies, other_with_cloud, g_points, g_weights, state, geometry, refraction_mask, want_contrib=True ) D_clear, dR2_clear, layer_dR2_clear = _integrate_g_points_trans( sigma_perspecies, vmr_perspecies, other_no_cloud, g_points, g_weights, state, geometry, refraction_mask, want_contrib=True ) D_net = f_cloud * D_cloud + (1.0 - f_cloud) * D_clear dR2 = f_cloud * dR2_cloud + (1.0 - f_cloud) * dR2_clear layer_dR2 = f_cloud * layer_dR2_cloud + (1.0 - f_cloud) * layer_dR2_clear contrib_func_norm = layer_dR2 / jnp.maximum(dR2[None, :], 1e-30) else: D_cloud, _, _ = _integrate_g_points_trans( sigma_perspecies, vmr_perspecies, other_with_cloud, g_points, g_weights, state, geometry, refraction_mask, want_contrib=False ) D_clear, _, _ = _integrate_g_points_trans( sigma_perspecies, vmr_perspecies, other_no_cloud, g_points, g_weights, state, geometry, refraction_mask, want_contrib=False ) D_net = f_cloud * D_cloud + (1.0 - f_cloud) * D_clear contrib_func_norm = jnp.zeros((state["dz"].shape[0], D_net.shape[0]), dtype=D_net.dtype) else: # No cloud fraction handling if contri_func: D_net, dR2, layer_dR2 = _integrate_g_points_trans( sigma_perspecies, vmr_perspecies, other_opacity_2d, g_points, g_weights, state, geometry, refraction_mask, want_contrib=True ) contrib_func_norm = layer_dR2 / jnp.maximum(dR2[None, :], 1e-30) else: D_net, _, _ = _integrate_g_points_trans( sigma_perspecies, vmr_perspecies, other_opacity_2d, g_points, g_weights, state, geometry, refraction_mask, want_contrib=False ) contrib_func_norm = jnp.zeros((state["dz"].shape[0], D_net.shape[0]), dtype=D_net.dtype) return D_net, contrib_func_norm