"""
refraction.py
=============
Approximate refraction support for transmission spectroscopy.
Current implementation: "cutoff" mode (option A) that applies a refractive
boundary (fully opaque below a wavelength-dependent impact parameter) without
curved-ray optical-depth integration.
"""
from __future__ import annotations
from typing import Dict
import jax.numpy as jnp
from jax import lax
from . import registry_ray as XR
from .data_constants import AU, kb, amu
# Public API of this module: the names exported by a star-import.
__all__ = ["refraction_cutoff_mask", "maybe_refraction_cutoff_mask"]
def refraction_cutoff_mask(
    state: Dict[str, jnp.ndarray],
    params: Dict[str, jnp.ndarray],
    opac: Dict[str, jnp.ndarray],
) -> jnp.ndarray:
    """Return a boolean mask for impact parameters blocked by refraction.

    The mask is defined on the same "impact parameter grid" used by the current
    transit RT kernels: ``b ≈ R0 + z_lay`` (layer midpoints).

    For each (layer, wavelength), the bending angle is estimated with the
    exponential-atmosphere approximation::

        alpha(b, λ) ≈ (n(b, λ) - 1) * sqrt(2π b / H)

    and a ray is marked as blocked if::

        alpha(b, λ) > theta_star,    theta_star = asin(R_s / a)

    where ``a = a_sm * AU``.

    Parameters
    ----------
    state : dict
        Must contain `R0`, `R_s`, `z_lay`, `T_lay`, `mu_lay`, `nd_lay`, and
        `vmr_lay`. `vmr_lay` is indexed by species name; each entry is
        broadcast to one value per layer.
    params : dict
        Must contain `log_10_g` and `a_sm` (AU).
    opac : dict
        Must contain `ray_refractivity_coeff_table`, a 2-D (species, wavelength)
        table. Its rows are combined with the species order returned by
        `registry_ray.ray_runtime_species_order()`, so it must follow that order.

    Returns
    -------
    mask : jnp.ndarray, shape (nlay, nwl), dtype bool
        True where refraction blocks stellar rays (treat as fully opaque).

    Raises
    ------
    RuntimeError
        If `ray_refractivity_coeff_table` is missing from `opac`.
    """
    # Validate the opacity cache before touching `state`/`params`, so the
    # explicit error below is what the caller sees rather than an incidental
    # KeyError from a later lookup.
    if "ray_refractivity_coeff_table" not in opac:
        raise RuntimeError(
            "Refraction requested but Rayleigh refractivity tables are missing from opac cache."
        )

    nd_lay = state["nd_lay"]  # (nlay,)
    nlay = nd_lay.shape[0]

    # Stellar angular radius as seen from the planet. The ratio is clipped to
    # [0, 1] so arcsin stays finite.
    a_cm = params["a_sm"] * AU
    theta_star = jnp.arcsin(jnp.clip(state["R_s"] / a_cm, 0.0, 1.0))

    # Scale height at each layer midpoint, with gravity scaled by (R0 / b)^2.
    R0 = state["R0"]
    b = R0 + state["z_lay"]  # (nlay,)
    g0 = 10.0 ** params["log_10_g"]
    g_z = g0 * (R0 / b) ** 2
    H = (kb * state["T_lay"]) / (state["mu_lay"] * amu * g_z)  # (nlay,)

    # Build (n - 1) per (layer, wavelength): mixing-ratio-weighted sum of the
    # per-species coefficients, scaled by the layer number density.
    refractivity_coeff = opac["ray_refractivity_coeff_table"]  # (nspec, nwl)
    species_names = XR.ray_runtime_species_order()
    vmr_lay = state["vmr_lay"]
    mixing_ratios = jnp.stack(
        [jnp.broadcast_to(vmr_lay[name], (nlay,)) for name in species_names],
        axis=0,
    )  # (nspec, nlay)
    nm1_coeff = jnp.einsum("sl,sw->lw", mixing_ratios, refractivity_coeff)
    # Negative refractivities are clamped to zero.
    nm1_layer = jnp.maximum(nd_lay[:, None] * nm1_coeff, 0.0)  # (nlay, nwl)

    # Exponential-atmosphere bending approximation. H is floored at 1.0 to
    # avoid dividing by a zero scale height.
    alpha = nm1_layer * jnp.sqrt(
        2.0 * jnp.pi * b[:, None] / jnp.maximum(H[:, None], 1.0)
    )
    return alpha > theta_star
def maybe_refraction_cutoff_mask(
    state: Dict[str, jnp.ndarray],
    params: Dict[str, jnp.ndarray],
    opac: Dict[str, jnp.ndarray] | None,
) -> jnp.ndarray:
    """Return a JAX-safe refraction mask or an all-false mask.

    This avoids Python-side branching on traced `refraction_mode` values inside
    jitted/vmapped transit kernels: the choice between the real mask and the
    all-false mask is made with `lax.cond`.

    Parameters
    ----------
    state : dict
        Model state. The layer count is taken from `z_lay`, else `dz`, else
        the integer `nlay`; the wavelength count from `wl`, else the integer
        `nwl`. `refraction_mode == 1` enables the cutoff.
    params : dict
        Model parameters; `log_10_g` and `a_sm` are needed for the cutoff.
    opac : dict or None
        Opacity cache; `ray_refractivity_coeff_table` is needed for the cutoff.

    Returns
    -------
    mask : jnp.ndarray, shape (nlay, nwl), dtype bool
        The result of `refraction_cutoff_mask` when `refraction_mode == 1` and
        every required input is present; otherwise all False.
    """
    # Grid sizes come from array shapes (or plain ints), so they are static
    # Python values.
    if "z_lay" in state:
        nlay = state["z_lay"].shape[0]
    elif "dz" in state:
        nlay = state["dz"].shape[0]
    else:
        nlay = int(state["nlay"])
    nwl = state["wl"].shape[0] if "wl" in state else int(state["nwl"])
    zeros = jnp.zeros((nlay, nwl), dtype=bool)

    # Structural checks on dict keys are static, so plain Python branching is
    # safe here.
    if opac is None or "ray_refractivity_coeff_table" not in opac:
        return zeros
    # No mode flag means refraction is off. Return directly instead of
    # building a cond whose "on" branch would still be traced (and could fail
    # on inputs that were never prepared for refraction).
    if "refraction_mode" not in state:
        return zeros

    required_state = ("wl", "T_lay", "mu_lay", "nd_lay", "vmr_lay", "z_lay", "R0", "R_s")
    required_params = ("log_10_g", "a_sm")
    if any(name not in state for name in required_state):
        return zeros
    if any(name not in params for name in required_params):
        return zeros

    # `refraction_mode` may be a traced value, so select the branch with
    # lax.cond rather than a Python `if`.
    refraction_mode = jnp.asarray(state["refraction_mode"])
    return lax.cond(
        refraction_mode == 1,
        lambda _: refraction_cutoff_mask(state, params, opac),
        lambda _: zeros,
        operand=None,
    )