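"""
RT_trans_1D_os.py
=================

Transit-depth (transmission spectrum) calculation for a 1D atmosphere.

Per-layer opacity components are summed, converted to slant optical depths
along chords through concentric spherical shells, and integrated into the
wavelength-dependent transit depth ``(R_base**2 + dR2) / R_s**2``. An optional
cloud fraction mixes a cloudy and a cloud-free spectrum, and a normalised
per-layer contribution function can be returned alongside the depth.
"""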
from __future__ import annotations
from typing import Dict, Mapping, Tuple
import jax.numpy as jnp
from .refraction import maybe_refraction_cutoff_mask
# Only the transit-depth entry point is exported; the underscore-prefixed
# helpers in this module are internal.
__all__ = ["compute_transit_depth_1d_os"]
def _sum_opacity_components_os(
    state: Dict[str, jnp.ndarray],
    opacity_components: Mapping[str, jnp.ndarray],
) -> jnp.ndarray:
    """Sum the recognised opacity components into a single grid.

    Recognised keys, in summation order, are ``"line"``, ``"rayleigh"``,
    ``"cia"``, ``"special"`` and ``"cloud"``. Any other key is ignored. A
    missing recognised key contributes zero, using a zero array shaped like
    the first recognised component found.

    If *opacity_components* is empty, or yields no usable recognised
    component, a zero array of shape ``(nlay, nwl)`` is returned. Here
    ``nlay`` is ``state["rho_lay"].shape[0]`` and ``nwl`` is
    ``state["wl"].shape[0]``, falling back to ``int(state["nwl"])`` when
    ``"wl"`` is absent.
    """
    n_layers = state["rho_lay"].shape[0]
    if "wl" in state:
        n_wavelengths = state["wl"].shape[0]
    else:
        n_wavelengths = int(state["nwl"])
    fallback_shape = (n_layers, n_wavelengths)

    if not opacity_components:
        return jnp.zeros(fallback_shape)

    recognised = ("line", "rayleigh", "cia", "special", "cloud")
    found = [key for key in recognised if key in opacity_components]
    template = opacity_components.get(found[0]) if found else None
    if template is None:
        return jnp.zeros(fallback_shape)

    blank = jnp.zeros_like(template)
    per_component = [opacity_components.get(key, blank) for key in recognised]
    return jnp.stack(per_component, axis=0).sum(axis=0)
def _build_transit_geometry(state: Dict[str, jnp.ndarray]) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Precompute geometry terms for the transit-depth calculation.

    Each chord ``i`` has an impact parameter equal to the layer mid-point
    radius ``r_mid[i]``. Each spherical shell ``j`` lies between ``r_low[j]``
    and ``r_up[j]``. For every chord/shell pair this computes the chord
    length inside the shell divided by the shell thickness ``dr[j]``.
    Multiplying this matrix by vertical optical-depth increments
    (``P1D @ dtau_v``) gives slant optical depths.

    Parameters
    ----------
    state : dict
        Must provide ``"R0"`` (reference radius), ``"z_lev"`` (level
        altitudes) and ``"z_lay"`` (layer altitudes). The elementwise product
        ``2 * r_mid * dr`` requires ``len(z_lay) == len(z_lev) - 1``.

    Returns
    -------
    P1D : jnp.ndarray
        ``(nlay, nlay)`` path-length factors. The row is the chord, the
        column is the shell.
    area_weight : jnp.ndarray
        ``(nlay,)`` annulus weights ``2 * r_mid * dr``.
    """

    def _clamped_sqrt(x: jnp.ndarray) -> jnp.ndarray:
        # Equivalent to sqrt(max(x, 0)) in the forward pass. However,
        # jnp.sqrt has an infinite derivative at 0, so clamping the argument
        # would yield 0 * inf = NaN gradients for every entry with x <= 0
        # (chords that miss a shell). Feeding sqrt a harmless 1.0 there keeps
        # reverse-mode gradients finite.
        positive = x > 0.0
        return jnp.where(positive, jnp.sqrt(jnp.where(positive, x, 1.0)), 0.0)

    R0 = state["R0"]
    z_lev = state["z_lev"]
    z_lay = state["z_lay"]

    r_mid = R0 + z_lay
    r_low = R0 + z_lev[:-1]
    r_up = R0 + z_lev[1:]
    dr = r_up - r_low

    r_mid_2d = r_mid[:, None]  # chord impact parameter (rows)
    r_up_2d = r_up[None, :]    # shell outer radius (columns)
    r_low_2d = r_low[None, :]  # shell inner radius (columns)
    dr_2d = dr[None, :]

    sqrt_up = _clamped_sqrt(r_up_2d**2 - r_mid_2d**2)
    sqrt_low = _clamped_sqrt(r_low_2d**2 - r_mid_2d**2)

    # Case 1: the shell lies entirely below the chord, so there is no path.
    P_case1 = jnp.zeros_like(sqrt_up)
    # Case 2: the chord's closest approach lies inside the shell.
    P_case2 = 2.0 / dr_2d * sqrt_up
    # Case 3: the chord passes through the shell twice (entry and exit).
    P_case3 = 2.0 / dr_2d * (sqrt_up - sqrt_low)
    cond1 = r_up_2d <= r_mid_2d
    cond2 = (r_low_2d <= r_mid_2d) & (r_mid_2d < r_up_2d)
    P1D = jnp.where(cond1, P_case1, jnp.where(cond2, P_case2, P_case3))

    area_weight = 2.0 * r_mid * dr
    return P1D, area_weight
def _get_base_transit_radius(state: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """Radius of the opaque disc: ``R0`` plus the first level altitude.

    The first entry of ``state["z_lev"]`` is treated as the deepest level of
    the grid.
    """
    reference_radius = state["R0"]
    bottom_level = state["z_lev"][0]
    return reference_radius + bottom_level
def _transit_depth_and_contrib_from_opacity(
    state: Dict[str, jnp.ndarray],
    k_tot: jnp.ndarray,  # (nlay, nwl)
    geometry: tuple[jnp.ndarray, jnp.ndarray],
    refraction_mask: jnp.ndarray | None,
    want_contrib: bool,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Return the transit depth, its atmospheric part and per-layer contributions.

    Parameters
    ----------
    state : dict
        Reads ``"R_s"``, ``"rho_lay"`` and ``"dz"``, plus whatever
        ``_get_base_transit_radius`` reads.
    k_tot : jnp.ndarray
        ``(nlay, nwl)`` opacity. It is multiplied by ``rho_lay * dz`` to give
        vertical optical-depth increments and is floored at 1e-99 first.
    geometry : tuple
        ``(P1D, area_weight)`` as produced by ``_build_transit_geometry``.
    refraction_mask : jnp.ndarray or None
        Where True, the slant optical depth is forced to 1e30, so that chord
        is treated as fully opaque.
    want_contrib : bool
        If False, the contribution calculation is skipped and ``layer_dR2``
        is returned as zeros.

    Returns
    -------
    D : jnp.ndarray
        Transit depth per wavelength, ``(R_base**2 + dR2) / R_s**2``.
    dR2 : jnp.ndarray
        Atmospheric term ``sum_i area_weight[i] * (1 - exp(-tau_i))``.
    layer_dR2 : jnp.ndarray
        ``(nlay, nwl)`` share of ``dR2`` attributed to each layer. Summed over
        layers it reproduces ``dR2`` up to rounding, except for chords forced
        opaque by ``refraction_mask``.
    """
    R_base = _get_base_transit_radius(state)
    R_s = state["R_s"]
    rho = state["rho_lay"]
    dz = state["dz"]
    P1D, area_weight = geometry

    # Floor the opacity at a tiny positive value (also clamps negative inputs).
    k_eff = jnp.maximum(k_tot, 1.0e-99)
    dtau_v = k_eff * rho[:, None] * dz[:, None]
    tau_path = jnp.matmul(P1D, dtau_v)  # slant optical depth per chord
    if refraction_mask is not None:
        tau_path = jnp.where(refraction_mask, 1.0e30, tau_path)

    # -expm1(-tau) == 1 - exp(-tau), but keeps full relative precision for
    # optically thin chords, where 1 - exp(-tau) cancels to (near) zero.
    one_minus_trans = -jnp.expm1(-tau_path)
    dR2 = jnp.sum(area_weight[:, None] * one_minus_trans, axis=0)
    D = (R_base**2 + dR2) / (R_s**2)

    if not want_contrib:
        return D, dR2, jnp.zeros_like(dtau_v)

    # (1 - exp(-tau)) / tau -> 1 as tau -> 0. Use a safe denominator in the
    # discarded branch so reverse-mode gradients never see 0/0 when tau == 0.
    tau_eps = 1.0e-30
    resolvable = tau_path > tau_eps
    safe_tau = jnp.where(resolvable, tau_path, 1.0)
    ratio = jnp.where(resolvable, one_minus_trans / safe_tau, 1.0)

    # Distribute each chord's weighted absorption back onto the layers it crosses.
    W = area_weight[:, None] * ratio
    geom_weighted = jnp.matmul(P1D.T, W)
    layer_dR2 = dtau_v * geom_weighted
    return D, dR2, layer_dR2
def _transit_depth_from_opacity(
    state: Dict[str, jnp.ndarray],
    k_tot: jnp.ndarray,  # (nlay, nwl)
    geometry: tuple[jnp.ndarray, jnp.ndarray],
    refraction_mask: jnp.ndarray | None,
) -> jnp.ndarray:
    """Return only the transit depth ``(R_base**2 + dR2) / R_s**2`` per wavelength.

    Delegates to ``_transit_depth_and_contrib_from_opacity`` with
    ``want_contrib=False``. The depth calculation then lives in one place, so
    the spectrum cannot differ depending on whether a contribution function
    was requested. The extra return values are discarded.

    Parameters
    ----------
    state : dict
        Atmosphere state, forwarded unchanged.
    k_tot : jnp.ndarray
        ``(nlay, nwl)`` opacity grid.
    geometry : tuple
        ``(P1D, area_weight)`` as produced by ``_build_transit_geometry``.
    refraction_mask : jnp.ndarray or None
        Optional mask of chords to treat as fully opaque.
    """
    D, _dR2, _layer_dR2 = _transit_depth_and_contrib_from_opacity(
        state,
        k_tot,
        geometry=geometry,
        refraction_mask=refraction_mask,
        want_contrib=False,
    )
    return D
def compute_transit_depth_1d_os(
    state: Dict[str, jnp.ndarray],
    params: Dict[str, jnp.ndarray],
    opacity_components: Mapping[str, jnp.ndarray],
    opac: Mapping[str, jnp.ndarray] | None = None,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Compute the 1D transit-depth spectrum and, optionally, its contribution function.

    Parameters
    ----------
    state : dict
        Atmosphere and grid state. This function and its helpers read
        ``"R0"``, ``"R_s"``, ``"z_lev"``, ``"z_lay"``, ``"dz"``,
        ``"rho_lay"``, ``"wl"`` (or ``"nwl"``) and the optional flag
        ``"contri_func"`` (default False). The whole dict is also passed to
        ``maybe_refraction_cutoff_mask``.
    params : dict
        Model parameters. Suppose it contains ``"f_cloud"`` and a ``"cloud"``
        opacity component is supplied. The result is then
        ``f * cloudy + (1 - f) * clear``, with ``f = clip(f_cloud, 0, 1)``.
        It is also passed to ``maybe_refraction_cutoff_mask``.
    opacity_components : Mapping
        Per-component opacity grids, combined by
        ``_sum_opacity_components_os``.
    opac : Mapping, optional
        Passed through to ``maybe_refraction_cutoff_mask``.

    Returns
    -------
    D_net : jnp.ndarray
        Transit depth per wavelength.
    contrib_func_norm : jnp.ndarray
        ``(nlay, nwl)`` per-layer contribution to ``dR2``, normalised by the
        total ``dR2``. It is all zeros when ``state["contri_func"]`` is falsy.
    """
    contri_func = state.get("contri_func", False)
    refraction_mask = maybe_refraction_cutoff_mask(state, params, opac)
    geometry = _build_transit_geometry(state)
    k_tot = _sum_opacity_components_os(state, opacity_components)  # (nlay, nwl)

    def _spectrum(k: jnp.ndarray):
        # Returns (D, dR2, layer_dR2) for one opacity grid. dR2 and layer_dR2
        # are None when no contribution function was requested.
        if contri_func:
            return _transit_depth_and_contrib_from_opacity(
                state, k, geometry=geometry, refraction_mask=refraction_mask, want_contrib=True
            )
        D = _transit_depth_from_opacity(state, k, geometry=geometry, refraction_mask=refraction_mask)
        return D, None, None

    if "f_cloud" in params and "cloud" in opacity_components:
        # Patchy clouds: linear mix of a fully cloudy and a cloud-free spectrum.
        f_cloud = jnp.clip(params["f_cloud"], 0.0, 1.0)
        k_no_cloud = k_tot - opacity_components["cloud"]
        cloudy = _spectrum(k_tot)
        clear = _spectrum(k_no_cloud)
        # Mix D, dR2 and layer_dR2 in turn; quantities that were not computed stay None.
        D_net, dR2, layer_dR2 = [
            None if c is None else f_cloud * c + (1.0 - f_cloud) * n
            for c, n in zip(cloudy, clear)
        ]
    else:
        D_net, dR2, layer_dR2 = _spectrum(k_tot)

    if contri_func:
        # Normalise by the (mixed) total; the floor avoids division by zero.
        contrib_func_norm = layer_dR2 / jnp.maximum(dR2[None, :], 1e-30)
    else:
        contrib_func_norm = jnp.zeros((state["dz"].shape[0], D_net.shape[0]), dtype=D_net.dtype)
    return D_net, contrib_func_norm