"""
opacity_cloud.py
================
"""
from typing import Dict, Tuple, Optional
import jax
import jax.numpy as jnp
from .aux_functions import pchip_1d
from .mie_schemes import rayleigh, madt
from .lxmie_mod import lxmie_jax
# Names exported by ``from opacity_cloud import *``.
__all__ = [
    "compute_cloud_efficiencies",
    "compute_cloud_opacity",
    "zero_cloud_opacity",
    "grey_const_cloud",
    "grey_profile_cloud",
    "deck_and_powerlaw",
    "F18_cloud",
    "direct_nk",
    "nk_f18_blend",
    "f18_skew_cloud",
]
# Settings forwarded to ``lxmie_jax`` as ``nmax``, ``cf_max_terms`` and ``cf_eps``.
_LXMIE_NMAX = 2000
_LXMIE_CF_MAX_TERMS = 2000
_LXMIE_CF_EPS = 1e-10
# Floor applied to denominators to avoid division by zero.
_DIV_EPS = 1e-30
# Cloud mass mixing ratios at or below this threshold are treated as exactly zero.
_QC_EPS = 1e-30
# Upper bound used when clipping Q_ext in ``f18_skew_cloud``.
_Q_EXT_MAX = 4.0
def _safe_div(num: jnp.ndarray, den: jnp.ndarray) -> jnp.ndarray:
"""Elementwise num/den with 0 where den==0 (avoids NaNs in cloud-free layers)."""
return jnp.where(den != 0, num / den, 0.0)
[docs]
def compute_cloud_efficiencies(
    wl: jnp.ndarray,
    r_cm: jnp.ndarray,
    params: Dict[str, jnp.ndarray],
    *,
    eff_scheme: str,
    n: Optional[jnp.ndarray] = None,
    k: Optional[jnp.ndarray] = None,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Second-stage landing function for (Q_ext, Q_sca, g).

    Parameters
    ----------
    wl : `~jax.numpy.ndarray`, shape (nwl,)
        Wavelength grid in microns.
    r_cm : `~jax.numpy.ndarray`, shape (nr,) or scalar
        Particle radius in cm.
    params : dict[str, `~jax.numpy.ndarray`]
        Parameters needed by the selected scheme (only the F18 scheme reads
        them here).
    eff_scheme : str
        Efficiency scheme identifier (case-insensitive, surrounding whitespace
        ignored). Current options:

        - "f18" / "fisher18" / "fisher_heng": Fisher & Heng (2018) Qext model
          (Qsca=0, g=0 for now)
        - "mie_madt" / "madt" / "rayleigh_madt": Rayleigh + MADT blend using
          the supplied n,k arrays
        - "lxmie" / "mie" / "mie_full" / "full_mie": full Lorenz-Mie (LX-MIE)
          using the supplied n,k arrays
    n, k : `~jax.numpy.ndarray`, optional
        Refractive-index arrays on the same wl grid (shape (nwl,)). For all
        nk-based schemes, these must be provided (the "physics" pathway uses
        cached n,k from the registry/opac_cache). The node-interpolation
        pathway is kept only in `direct_nk`.

    Returns
    -------
    Q_ext, Q_sca, g : arrays
        Efficiencies and asymmetry parameter. Shape is (nr, nwl) if r_cm is a
        vector, else (nwl,) for scalar radius.

    Raises
    ------
    ValueError
        If ``eff_scheme`` is not recognised, or if an nk-based scheme is
        requested without both ``n`` and ``k``.
    """
    scheme = eff_scheme.lower().strip()

    if scheme in ("f18", "fisher18", "fisher_heng"):
        return F18_cloud(wl, r_cm, params)

    use_madt = scheme in ("mie_madt", "madt", "rayleigh_madt")
    use_lxmie = scheme in ("lxmie", "mie", "mie_full", "full_mie")

    # Validate the scheme name first so that a typo is reported as such,
    # rather than as a misleading "n and k must be provided" error.
    if not (use_madt or use_lxmie):
        raise ValueError(
            f"Unknown eff_scheme='{eff_scheme}'. "
            "Valid options: f18, mie_madt, lxmie."
        )

    # nk-based schemes: physics pathway only (cached n,k provided by caller).
    if n is None or k is None:
        raise ValueError(
            "compute_cloud_efficiencies: n and k must be provided for nk-based schemes "
            "(use cached n,k from the registry/opac_cache)."
        )

    # Cached n,k cover the whole grid, so every wavelength is "in support".
    wl_support_mask = jnp.ones_like(wl, dtype=bool)

    if use_madt:
        return _efficiencies_mie_madt(wl, r_cm, n, k, wl_support_mask)

    return _efficiencies_lxmie(
        wl,
        r_cm,
        n,
        k,
        wl_support_mask,
        nmax=_LXMIE_NMAX,
        cf_max_terms=_LXMIE_CF_MAX_TERMS,
        cf_eps=_LXMIE_CF_EPS,
    )
def _efficiencies_mie_madt(
    wl: jnp.ndarray,
    r_cm: jnp.ndarray,
    n: jnp.ndarray,
    k: jnp.ndarray,
    wl_support_mask: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Rayleigh + MADT blend on a (r, wl) grid using broadcasting.

    Parameters
    ----------
    wl : `~jax.numpy.ndarray`, shape (nwl,)
        Wavelength grid in microns.
    r_cm : `~jax.numpy.ndarray` or float, shape (nr,) or scalar
        Particle radius in cm.
    n, k : `~jax.numpy.ndarray`
        Real and imaginary refractive index, passed straight to ``rayleigh``
        and ``madt``.
    wl_support_mask : `~jax.numpy.ndarray` of bool, shape (nwl,)
        Wavelengths where the result is kept; all outputs are zero elsewhere.

    Returns
    -------
    Q_ext, Q_sca, g : `~jax.numpy.ndarray`
        Shape (nwl,) for a scalar radius, (nr, nwl) for a radius vector.
    """
    # asarray lets a plain Python float radius through (it has no ``.ndim``).
    r_um = jnp.asarray(r_cm) * 1e4
    scalar_radius = r_um.ndim == 0

    if scalar_radius:
        x = 2.0 * jnp.pi * r_um / wl
        mask = wl_support_mask
    else:
        x = 2.0 * jnp.pi * r_um[:, None] / wl[None, :]
        mask = wl_support_mask[None, :]

    Q_ext_ray, Q_sca_ray, g_ray = rayleigh(n, k, x)
    Q_ext_madt, Q_sca_madt, g_madt = madt(n, k, x)

    # Smooth blend between Rayleigh (x=1.0) and MADT (x=3.0) using smootherstep
    t = jnp.clip((x - 1.0) / 2.0, 0.0, 1.0)
    w = 6.0 * t**5 - 15.0 * t**4 + 10.0 * t**3
    Q_ext = (1.0 - w) * Q_ext_ray + w * Q_ext_madt
    Q_sca = (1.0 - w) * Q_sca_ray + w * Q_sca_madt
    g = (1.0 - w) * g_ray + w * g_madt

    # Zero out wavelengths flagged as unsupported by the caller.
    Q_ext = jnp.where(mask, Q_ext, 0.0)
    Q_sca = jnp.where(mask, Q_sca, 0.0)
    g = jnp.where(mask, g, 0.0)
    return Q_ext, Q_sca, g
def _efficiencies_lxmie(
    wl: jnp.ndarray,
    r_cm: jnp.ndarray,
    n: jnp.ndarray,
    k: jnp.ndarray,
    wl_support_mask: jnp.ndarray,
    *,
    nmax: int,
    cf_max_terms: int,
    cf_eps: float,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Full Mie efficiencies (LX-MIE) on a (r, wl) grid.

    Parameters
    ----------
    wl : `~jax.numpy.ndarray`, shape (nwl,)
        Wavelength grid in microns.
    r_cm : `~jax.numpy.ndarray` or float, shape (nr,) or scalar
        Particle radius in cm.
    n, k : `~jax.numpy.ndarray`, shape (nwl,)
        Real and imaginary refractive index on the ``wl`` grid.
    wl_support_mask : `~jax.numpy.ndarray` of bool, shape (nwl,)
        Wavelengths for which ``lxmie_jax`` is evaluated; zeros are returned
        for the others.
    nmax, cf_max_terms, cf_eps
        Forwarded unchanged to ``lxmie_jax``.

    Returns
    -------
    Q_ext, Q_sca, g : `~jax.numpy.ndarray`
        Shape (nwl,) for a scalar radius, (nr, nwl) for a radius vector.
    """
    # asarray lets a plain Python float radius through (it has no ``.ndim``).
    r_um = jnp.asarray(r_cm) * 1e4
    # NOTE(review): complex128 presumably needs JAX 64-bit mode enabled — confirm.
    ri = (n + 1.0j * k).astype(jnp.complex128)  # lxmie_mod expects n + i k

    def _one_radius(r_um_val: jnp.ndarray):
        # Size parameter at every wavelength for this single radius.
        x_wl = (2.0 * jnp.pi * r_um_val) / wl

        def _one_wl(x_val, ri_val, in_support):
            def do():
                q_ext, q_sca, _q_abs, g = lxmie_jax(
                    ri_val, x_val, nmax=nmax, cf_max_terms=cf_max_terms, cf_eps=cf_eps
                )
                return q_ext, q_sca, g

            def skip():
                z = jnp.zeros_like(x_val)
                return z, z, z

            # NOTE(review): under vmap a batched predicate is presumably lowered
            # to a select, so both branches may still be evaluated — confirm.
            return jax.lax.cond(in_support, do, skip)

        Q_ext, Q_sca, g = jax.vmap(_one_wl, in_axes=(0, 0, 0))(x_wl, ri, wl_support_mask)
        return Q_ext, Q_sca, g

    if r_um.ndim == 0:
        return _one_radius(r_um)

    Q_ext, Q_sca, g = jax.vmap(_one_radius, in_axes=(0,))(r_um)
    return Q_ext, Q_sca, g
def _compute_mie_madt_efficiencies(
    wl_val: jnp.ndarray,
    n_val: jnp.ndarray,
    k_val: jnp.ndarray,
    r_eff: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Blend Rayleigh and MADT efficiencies according to the size parameter.

    The size parameter is ``x = 2 pi r_eff / wl_val`` (with the wavelength
    floored at 1e-12). ``rayleigh`` and ``madt`` are both evaluated at ``x``
    and mixed with a smootherstep weight that is 0 for ``x <= 1`` and 1 for
    ``x >= 3``.

    Parameters
    ----------
    wl_val : `~jax.numpy.ndarray`
        Wavelength in microns.
    n_val : `~jax.numpy.ndarray`
        Real part of the refractive index.
    k_val : `~jax.numpy.ndarray`
        Imaginary part of the refractive index.
    r_eff : `~jax.numpy.ndarray`
        Effective particle radius in microns.

    Returns
    -------
    Q_ext : `~jax.numpy.ndarray`
        Extinction efficiency.
    Q_sca : `~jax.numpy.ndarray`
        Scattering efficiency.
    g : `~jax.numpy.ndarray`
        Asymmetry parameter.
    """
    size_param = 2.0 * jnp.pi * r_eff / jnp.maximum(wl_val, 1e-12)

    # Small-particle and large-particle limits, each as (Q_ext, Q_sca, g).
    small_limit = rayleigh(n_val, k_val, size_param)
    large_limit = madt(n_val, k_val, size_param)

    # Smootherstep ramp: x=1 -> weight 0 (pure Rayleigh), x=3 -> weight 1 (pure MADT).
    ramp = jnp.clip((size_param - 1.0) / 2.0, 0.0, 1.0)
    weight = 6.0 * ramp**5 - 15.0 * ramp**4 + 10.0 * ramp**3

    Q_ext, Q_sca, g = (
        (1.0 - weight) * lo + weight * hi
        for lo, hi in zip(small_limit, large_limit)
    )
    return Q_ext, Q_sca, g
def _compute_mie_or_zero(
    wl_val: jnp.ndarray,
    n_val: jnp.ndarray,
    k_val: jnp.ndarray,
    r_eff: jnp.ndarray,
    is_in_support: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Return the Rayleigh/MADT efficiencies, or zeros when outside support.

    ``jax.lax.cond`` selects between evaluating
    ``_compute_mie_madt_efficiencies`` and returning zeros shaped like
    ``wl_val``.

    Parameters
    ----------
    wl_val : `~jax.numpy.ndarray`
        Wavelength in microns.
    n_val : `~jax.numpy.ndarray`
        Real part of the refractive index.
    k_val : `~jax.numpy.ndarray`
        Imaginary part of the refractive index.
    r_eff : `~jax.numpy.ndarray`
        Effective particle radius in microns.
    is_in_support : `~jax.numpy.ndarray`
        Boolean indicating if wavelength is within node support range.

    Returns
    -------
    Q_ext : `~jax.numpy.ndarray`
        Extinction efficiency (0.0 if outside support).
    Q_sca : `~jax.numpy.ndarray`
        Scattering efficiency (0.0 if outside support).
    g : `~jax.numpy.ndarray`
        Asymmetry parameter (0.0 if outside support).
    """

    def _evaluate():
        return _compute_mie_madt_efficiencies(wl_val, n_val, k_val, r_eff)

    def _blank():
        # Three independent zero arrays: (Q_ext, Q_sca, g).
        return tuple(jnp.zeros_like(wl_val) for _ in range(3))

    return jax.lax.cond(is_in_support, _evaluate, _blank)
def _compute_mie_madt_efficiencies_masked(
    wl: jnp.ndarray,
    n: jnp.ndarray,
    k: jnp.ndarray,
    r_eff: jnp.ndarray,
    wl_support_mask: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Vectorise ``_compute_mie_or_zero`` over the wavelength axis.

    ``wl``, ``n``, ``k`` and ``wl_support_mask`` are mapped along axis 0;
    ``r_eff`` is shared by every wavelength.
    """
    per_wavelength = jax.vmap(_compute_mie_or_zero, in_axes=(0, 0, 0, None, 0))
    return per_wavelength(wl, n, k, r_eff, wl_support_mask)
def compute_cloud_efficiencies_cached_nk(
    wl: jnp.ndarray,
    r_cm: jnp.ndarray,
    n: jnp.ndarray,
    k: jnp.ndarray,
    *,
    eff_scheme: str = "mie_madt",
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Shortcut for ``compute_cloud_efficiencies`` with cached n,k arrays.

    Calls ``compute_cloud_efficiencies`` with an empty parameter dictionary
    and the given ``n`` and ``k``.
    """
    no_params: Dict[str, jnp.ndarray] = {}
    return compute_cloud_efficiencies(
        wl,
        r_cm,
        no_params,
        eff_scheme=eff_scheme,
        n=n,
        k=k,
    )
[docs]
def compute_cloud_opacity(
    state: Dict[str, jnp.ndarray],
    params: Dict[str, jnp.ndarray],
    opacity_scheme: str = "none",
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Main landing function for cloud opacity calculation.

    This function dispatches to specific cloud opacity schemes based on the
    opacity_scheme parameter. It expects the vertical cloud profile (q_c_lay)
    to already be present in the state dictionary (computed by vert_cloud kernels).
    The F18 scheme is evaluated inline: particle-size distribution ->
    efficiencies -> mass opacity.

    Parameters
    ----------
    state : dict[str, `~jax.numpy.ndarray`]
        Atmospheric state dictionary. The F18 pathway reads:

        - `wl` : `~jax.numpy.ndarray`, shape (nwl,)
            Wavelength grid in microns.
        - `rho_lay` : `~jax.numpy.ndarray`, shape (nlay,)
            Layer mass density in g cm⁻³.
        - `q_c_lay` : `~jax.numpy.ndarray`, shape (nlay,)
            Cloud mass mixing ratio per layer (from vert_cloud kernel).

        Other schemes read the keys documented on their own functions
        (for example `nlay` and `nwl` for the grey and zero schemes).
    params : dict[str, `~jax.numpy.ndarray`]
        Parameter dictionary containing scheme-specific parameters.
        The F18 pathway reads `cld_rho` (g cm⁻³), `log_10_cld_r`
        (log₁₀ radius in microns), optional `cloud_dist`
        (1 = monodisperse, 2 = lognormal; default 1), optional `cld_sigma`
        (lognormal geometric width; default 1.0), plus the keys used by
        `F18_cloud`.
    opacity_scheme : str, optional
        Cloud opacity scheme identifier. Options:

        - `"none"` or `"zero"`: No cloud opacity (default)
        - `"grey_const"`: Wavelength-independent grey opacity in every layer
        - `"grey_profile"`: Wavelength-independent grey opacity masked by q_c_lay
        - `"direct_nk"`: Retrieved refractive index with Mie/MADT scattering
        - `"F18"`: Fisher & Heng (2018) empirical model
        - `"madt_rayleigh"`: Mie/MADT blend using cached n,k on master grid
        - `"lxmie"`: Full Lorenz-Mie (LX-MIE) using cached n,k on master grid
        - `"powerlaw"`: Grey + power-law wavelength dependence

    Returns
    -------
    k_cld : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Cloud extinction coefficient in cm² g⁻¹.
    ssa : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Single-scattering albedo (Q_sca / Q_ext).
    g : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Asymmetry parameter for scattering phase function.

    Raises
    ------
    ValueError
        If ``opacity_scheme`` is not one of the recognised identifiers.
    """
    scheme_lower = opacity_scheme.lower().strip()

    # Schemes that are fully handled by a dedicated kernel.
    if scheme_lower in ("none", "zero", "off", "no_cloud"):
        return zero_cloud_opacity(state, params)
    if scheme_lower == "grey_const":
        return grey_const_cloud(state, params)
    if scheme_lower in ("grey_profile", "grey_slab"):
        return grey_profile_cloud(state, params)
    if scheme_lower in ("powerlaw", "power_law", "deck_and_powerlaw"):
        return deck_and_powerlaw(state, params)
    if scheme_lower in ("direct_nk", "nk"):
        # Keep this pathway self-contained (legacy-style).
        return direct_nk(state, params)
    if scheme_lower in ("madt_rayleigh", "madt-rayleigh", "mie_madt"):
        return _cached_nk_mie_cloud(state, params, eff_scheme="mie_madt")
    if scheme_lower in ("lxmie", "mie_full", "full_mie"):
        return _cached_nk_mie_cloud(state, params, eff_scheme="lxmie")
    if scheme_lower not in ("f18", "fisher18", "fisher_heng"):
        raise ValueError(
            f"Unknown cloud opacity scheme: '{opacity_scheme}'. "
            "Valid options: none, grey_const, grey_profile, powerlaw, direct_nk, f18, madt_rayleigh, lxmie"
        )

    # ------------------------------------------------------------
    # Microphysical clouds: distribution -> efficiencies -> opacity (F18)
    # ------------------------------------------------------------
    wl = state["wl"]  # (nwl,) microns
    rho_a = state["rho_lay"]  # (nlay,) g cm^-3
    q_c = state["q_c_lay"]  # (nlay,) dimensionless
    q_c = jnp.where(q_c > _QC_EPS, q_c, 0.0)
    rho_d = params["cld_rho"]  # g cm^-3

    # Particle size distribution code:
    #   1 = monodisperse, 2 = polydisperse (lognormal)
    cloud_dist_code = jnp.asarray(params.get("cloud_dist", 1), dtype=jnp.int32)
    eff_scheme = "f18"

    # Retrieved / configured radius parameter is in microns.
    r_um = 10.0 ** params["log_10_cld_r"]
    r_cm = r_um * 1e-4

    def _poly_case(_):
        # lax.cond traces both branches; use .get default so monodisperse configs
        # don't require lognormal params to be present.
        # NOTE: the default sigma of 1.0 gives ln(sigma) = 0, which makes the
        # divisions below singular, so cld_sigma must be > 1 when this branch
        # is actually selected.
        sig_g = params.get("cld_sigma", jnp.asarray(1.0))
        ln_sigma = jnp.log(sig_g)
        lnsig2 = ln_sigma ** 2

        # Total number density implied by q_c for a lognormal (geometric-mean) radius.
        N0 = (3.0 * rho_a * q_c) / (4.0 * jnp.pi * rho_d * r_cm**3) * jnp.exp(-4.5 * lnsig2)  # (nlay,)

        # Radius grid bounds are given in microns; the single 1e-4 factor on the
        # logspace result is the only micron -> cm conversion (converting the
        # bounds as well would shrink the grid by a further factor of 1e4).
        r_min_um = 1e-3
        r_max_um = 10.0
        nr = 20
        r_grid_cm = jnp.logspace(jnp.log10(r_min_um), jnp.log10(r_max_um), nr) * 1e-4  # (nr,) cm

        # Spectral number density n(r) [cm^-3 cm^-1], evaluated on-the-fly in scan.
        prefac = N0 / (jnp.sqrt(2.0 * jnp.pi) * ln_sigma)  # (nlay,)

        # Trapezoid weights on a non-uniform grid: integral y(r) dr ~= sum_i w_i * y_i
        dr = jnp.diff(r_grid_cm)  # (nr-1,)
        w_trap = jnp.concatenate(
            [
                dr[:1] / 2.0,
                (dr[:-1] + dr[1:]) / 2.0,
                dr[-1:] / 2.0,
            ],
            axis=0,
        )  # (nr,)

        def _accum(carry, r_w):
            alpha_ext, alpha_sca, alpha_sca_g = carry
            r_i, w_i = r_w  # cm, cm

            # n(r_i) for each layer [cm^-3 cm^-1]
            log_ratio_i = jnp.log(r_i / r_cm)
            exponent_i = -0.5 * (log_ratio_i / ln_sigma) ** 2
            f_i = prefac * jnp.exp(exponent_i) / r_i  # (nlay,)

            # Q over wavelength at this radius (nwl,)
            Q_ext_i, Q_sca_i, g_i = compute_cloud_efficiencies(wl, r_i, params, eff_scheme=eff_scheme)
            area_i = jnp.pi * (r_i**2)  # cm^2
            dA = w_i * area_i  # cm^3

            # Add contribution: ∫ n(r) Q πr^2 dr -> units cm^-1
            alpha_ext = alpha_ext + (f_i[:, None] * Q_ext_i[None, :] * dA)
            alpha_sca = alpha_sca + (f_i[:, None] * Q_sca_i[None, :] * dA)
            alpha_sca_g = alpha_sca_g + (f_i[:, None] * Q_sca_i[None, :] * g_i[None, :] * dA)
            return (alpha_ext, alpha_sca, alpha_sca_g), None

        alpha0 = jnp.zeros((rho_a.shape[0], wl.shape[0]), dtype=wl.dtype)
        (alpha_ext, alpha_sca, alpha_sca_g), _ = jax.lax.scan(
            _accum,
            (alpha0, alpha0, alpha0),
            (r_grid_cm, w_trap),
        )

        # Convert to mass opacities [cm^2 g^-1] by dividing by rho_a.
        k_ext = alpha_ext / rho_a[:, None]
        k_sca = alpha_sca / rho_a[:, None]

        # Scattering-weighted asymmetry parameter.
        g = _safe_div(alpha_sca_g, alpha_sca)
        ssa = _safe_div(k_sca, k_ext)
        return k_ext, ssa, g

    def _mono_case(_):
        # Monodisperse: N0 comes from condensate mass per volume with radius r_cm.
        N0 = (3.0 * rho_a * q_c) / (4.0 * jnp.pi * rho_d * r_cm**3)  # (nlay,) cm^-3
        Q_ext_wl, Q_sca_wl, g_wl = compute_cloud_efficiencies(wl, r_cm, params, eff_scheme=eff_scheme)  # (nwl,)

        alpha_ext = N0[:, None] * Q_ext_wl[None, :] * (jnp.pi * r_cm**2)  # (nlay, nwl) cm^-1
        alpha_sca = N0[:, None] * Q_sca_wl[None, :] * (jnp.pi * r_cm**2)  # (nlay, nwl) cm^-1
        k_ext = alpha_ext / rho_a[:, None]
        k_sca = alpha_sca / rho_a[:, None]
        ssa = _safe_div(k_sca, k_ext)
        g = g_wl[None, :] + jnp.zeros_like(rho_a[:, None])

        # Cloud-free layers carry no optical properties at all.
        q_mask = (q_c > 0)[:, None]
        k_ext = jnp.where(q_mask, k_ext, 0.0)
        ssa = jnp.where(q_mask, ssa, 0.0)
        g = jnp.where(q_mask, g, 0.0)
        return k_ext, ssa, g

    return jax.lax.cond(cloud_dist_code == 2, _poly_case, _mono_case, operand=None)
def _cached_nk_mie_cloud(
    state: Dict[str, jnp.ndarray],
    params: Dict[str, jnp.ndarray],
    *,
    eff_scheme: str,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Cloud opacity from cached n,k + size distribution integration.

    Parameters
    ----------
    state : dict[str, `~jax.numpy.ndarray`]
        Must contain `wl` (microns), `rho_lay` (g cm⁻³), `q_c_lay`, and the
        cached refractive indices `cloud_nk_n` and `cloud_nk_k` on the `wl`
        grid.
    params : dict[str, `~jax.numpy.ndarray`]
        Must contain `cld_rho` (g cm⁻³) and `log_10_cld_r` (log₁₀ radius in
        microns). Optional: `cloud_dist` (1 = monodisperse, 2 = lognormal;
        default 1) and `cld_sigma` (default 1.0).
    eff_scheme : str
        Efficiency scheme forwarded to ``compute_cloud_efficiencies``.

    Returns
    -------
    k_ext, ssa, g : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Mass extinction coefficient, single-scattering albedo and asymmetry
        parameter. All three are zero in layers without cloud, and the whole
        computation is skipped when no layer has cloud.
    """
    wl = state["wl"]  # (nwl,) microns
    rho_a = state["rho_lay"]  # (nlay,) g cm^-3
    q_c = state["q_c_lay"]  # (nlay,) dimensionless
    q_c = jnp.where(q_c > _QC_EPS, q_c, 0.0)
    rho_d = params["cld_rho"]  # g cm^-3
    n = state["cloud_nk_n"]  # (nwl,)
    k = state["cloud_nk_k"]  # (nwl,)
    cloud_dist_code = jnp.asarray(params.get("cloud_dist", 1), dtype=jnp.int32)
    r_um = 10.0 ** params["log_10_cld_r"]
    r_cm = r_um * 1e-4
    has_cloud_any = jnp.any(q_c > 0)

    def _mono_case(_):
        N0 = (3.0 * rho_a * q_c) / (4.0 * jnp.pi * rho_d * r_cm**3)  # (nlay,) cm^-3
        Q_ext_wl, Q_sca_wl, g_wl = compute_cloud_efficiencies(wl, r_cm, params, eff_scheme=eff_scheme, n=n, k=k)
        alpha_ext = N0[:, None] * Q_ext_wl[None, :] * (jnp.pi * r_cm**2)
        alpha_sca = N0[:, None] * Q_sca_wl[None, :] * (jnp.pi * r_cm**2)
        k_ext = alpha_ext / rho_a[:, None]
        k_sca = alpha_sca / rho_a[:, None]
        ssa = _safe_div(k_sca, k_ext)
        g = g_wl[None, :] + jnp.zeros_like(rho_a[:, None])
        q_mask = (q_c > 0)[:, None]
        k_ext = jnp.where(q_mask, k_ext, 0.0)
        ssa = jnp.where(q_mask, ssa, 0.0)
        g = jnp.where(q_mask, g, 0.0)
        return k_ext, ssa, g

    def _poly_case(_):
        # NOTE: the default sigma of 1.0 gives ln(sigma) = 0, which makes the
        # divisions below singular, so cld_sigma must be > 1 when this branch
        # is actually selected.
        sig_g = params.get("cld_sigma", jnp.asarray(1.0))
        ln_sigma = jnp.log(sig_g)
        lnsig2 = ln_sigma ** 2
        N0 = (3.0 * rho_a * q_c) / (4.0 * jnp.pi * rho_d * r_cm**3) * jnp.exp(-4.5 * lnsig2)  # (nlay,)

        # Static radius grid, log-spaced from 1e-3 to 10 microns. logspace takes
        # base-10 *exponents*, hence the explicit log10; the single 1e-4 factor
        # converts the resulting micron values to cm.
        r_min_um = 1e-3
        r_max_um = 10.0
        nr = 20
        r_grid_cm = jnp.logspace(jnp.log10(r_min_um), jnp.log10(r_max_um), nr) * 1e-4  # (nr,) cm

        prefac = N0 / (jnp.sqrt(2.0 * jnp.pi) * ln_sigma)  # (nlay,)

        # Trapezoid weights on the non-uniform radius grid.
        dr = jnp.diff(r_grid_cm)
        w_trap = jnp.concatenate([dr[:1] / 2.0, (dr[:-1] + dr[1:]) / 2.0, dr[-1:] / 2.0], axis=0)

        def _accum(carry, r_w):
            alpha_ext, alpha_sca, alpha_sca_g = carry
            r_i, w_i = r_w
            # Lognormal number density at r_i for every layer.
            log_ratio_i = jnp.log(r_i / r_cm)
            exponent_i = -0.5 * (log_ratio_i / ln_sigma) ** 2
            f_i = prefac * jnp.exp(exponent_i) / r_i
            Q_ext_i, Q_sca_i, g_i = compute_cloud_efficiencies(wl, r_i, params, eff_scheme=eff_scheme, n=n, k=k)
            dA = w_i * jnp.pi * (r_i**2)
            alpha_ext = alpha_ext + (f_i[:, None] * Q_ext_i[None, :] * dA)
            alpha_sca = alpha_sca + (f_i[:, None] * Q_sca_i[None, :] * dA)
            alpha_sca_g = alpha_sca_g + (f_i[:, None] * Q_sca_i[None, :] * g_i[None, :] * dA)
            return (alpha_ext, alpha_sca, alpha_sca_g), None

        alpha0 = jnp.zeros((rho_a.shape[0], wl.shape[0]), dtype=wl.dtype)
        (alpha_ext, alpha_sca, alpha_sca_g), _ = jax.lax.scan(_accum, (alpha0, alpha0, alpha0), (r_grid_cm, w_trap))

        k_ext = alpha_ext / rho_a[:, None]
        k_sca = alpha_sca / rho_a[:, None]
        ssa = _safe_div(k_sca, k_ext)
        g = _safe_div(alpha_sca_g, alpha_sca)
        q_mask = (q_c > 0)[:, None]
        k_ext = jnp.where(q_mask, k_ext, 0.0)
        ssa = jnp.where(q_mask, ssa, 0.0)
        g = jnp.where(q_mask, g, 0.0)
        return k_ext, ssa, g

    def _do_cloud(_):
        return jax.lax.cond(cloud_dist_code == 2, _poly_case, _mono_case, operand=None)

    def _skip_cloud(_):
        zeros = jnp.zeros((q_c.shape[0], wl.shape[0]), dtype=wl.dtype)
        return zeros, zeros, zeros

    return jax.lax.cond(has_cloud_any, _do_cloud, _skip_cloud, operand=None)
[docs]
def zero_cloud_opacity(state: Dict[str, jnp.ndarray], params: Dict[str, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Produce all-zero cloud optical properties (the "no cloud" scheme).

    Parameters
    ----------
    state : dict[str, `~jax.numpy.ndarray`]
        State dictionary containing:

        - `nlay` : int
            Number of atmospheric layers.
        - `nwl` : int
            Number of wavelength points.
    params : dict[str, `~jax.numpy.ndarray`]
        Parameter dictionary (unused; kept for API compatibility).

    Returns
    -------
    k_cld : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Cloud extinction coefficient in cm² g⁻¹ (all zeros).
    ssa : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Single-scattering albedo (all zeros).
    g : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Asymmetry parameter (all zeros).
    """
    del params
    # The sizes are used as-is (no int() cast) to stay JIT compatible.
    n_layers, n_wavelengths = state["nlay"], state["nwl"]
    k_cld, ssa, g = (jnp.zeros((n_layers, n_wavelengths)) for _ in range(3))
    return k_cld, ssa, g
[docs]
def grey_const_cloud(state: Dict[str, jnp.ndarray], params: Dict[str, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Fill every layer and wavelength with one grey cloud opacity value.

    Parameters
    ----------
    state : dict[str, `~jax.numpy.ndarray`]
        State dictionary containing scalar entries `nlay` and `nwl`.
    params : dict[str, `~jax.numpy.ndarray`]
        Parameter dictionary containing:

        - `log_10_k_cld_grey` : float
            Log₁₀ of the grey cloud extinction coefficient in cm² g⁻¹.

    Returns
    -------
    k_cld : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Grey cloud extinction coefficient in cm² g⁻¹ in every layer.
    ssa : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Single-scattering albedo (zeros; pure absorption).
    g : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Asymmetry parameter (zeros).
    """
    # The sizes are used as-is (no int() cast) to stay JIT compatible.
    grid_shape = (state["nlay"], state["nwl"])
    k_grey = 10.0 ** params["log_10_k_cld_grey"]
    return (
        jnp.full(grid_shape, k_grey),
        jnp.zeros(grid_shape),
        jnp.zeros(grid_shape),
    )
[docs]
def grey_profile_cloud(state: Dict[str, jnp.ndarray], params: Dict[str, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Apply a grey cloud opacity only in layers that contain cloud.

    Parameters
    ----------
    state : dict[str, `~jax.numpy.ndarray`]
        Atmospheric state dictionary containing:

        - ``nlay`` : int
            Number of atmospheric layers.
        - ``nwl`` : int
            Number of wavelength points.
        - ``q_c_lay`` : `~jax.numpy.ndarray`, shape (nlay,)
            Cloud mass mixing ratio per layer from the selected ``vert_cloud``
            kernel. Layers with values greater than zero receive grey opacity.
    params : dict[str, `~jax.numpy.ndarray`]
        Parameter dictionary containing:

        - ``log_10_k_cld_grey`` : float
            Log10 grey cloud extinction coefficient in cm² g⁻¹.

    Returns
    -------
    k_cld : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Grey cloud extinction coefficient in cm² g⁻¹, set to
        ``10**log_10_k_cld_grey`` where ``q_c_lay > 0`` and zero elsewhere.
    ssa : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Single-scattering albedo (zeros; pure absorption).
    g : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Asymmetry parameter (zeros).
    """
    grid_shape = (state["nlay"], state["nwl"])
    k_grey = 10.0 ** params["log_10_k_cld_grey"]
    layer_has_cloud = state["q_c_lay"] > 0.0

    # One value per layer as a column, then spread across the wavelength axis.
    per_layer = jnp.where(layer_has_cloud[:, None], k_grey, 0.0)
    k_cld = per_layer + jnp.zeros(grid_shape)
    return k_cld, jnp.zeros(grid_shape), jnp.zeros(grid_shape)
[docs]
def deck_and_powerlaw(state: Dict[str, jnp.ndarray], params: Dict[str, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Compute a grey plus power-law cloud opacity, identical in every layer.

    The wavelength dependence is
    ``k(wl) = k_grey + k_powerlaw * (wl / wl_ref) ** (-alpha)``.

    Parameters
    ----------
    state : dict[str, `~jax.numpy.ndarray`]
        State dictionary containing:

        - ``wl`` : `~jax.numpy.ndarray`, shape (nwl,)
            Wavelength grid.
        - ``nlay`` : int
            Number of atmospheric layers.
    params : dict[str, `~jax.numpy.ndarray`]
        Parameter dictionary containing:

        - ``log_10_k_cld_grey`` : log10 of the constant grey component.
        - ``log_10_k_cld_Ray`` : log10 of the power-law amplitude at the
          reference wavelength.
        - ``alpha_cld`` : power-law exponent (alpha=4 gives a Rayleigh slope).
        - ``wl_ref_cld`` : reference wavelength, in the same units as ``wl``.

    Returns
    -------
    k_cld : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Cloud opacity, the same spectrum repeated for every layer.
    ssa : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Single-scattering albedo (zeros; pure absorption).
    g : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Asymmetry parameter (zeros).
    """
    wavelengths = state["wl"]
    n_layers = state["nlay"]

    grey_term = 10.0 ** params["log_10_k_cld_grey"]
    slope_amplitude = 10.0 ** params["log_10_k_cld_Ray"]
    exponent = params["alpha_cld"]
    reference_wl = params["wl_ref_cld"]

    # Spectrum shared by all layers: grey floor plus power-law slope.
    scaled_wl = wavelengths / reference_wl
    spectrum = grey_term + slope_amplitude * scaled_wl ** (-exponent)

    # Adding a zero column of height nlay broadcasts the spectrum to (nlay, nwl).
    k_cld = jnp.zeros((n_layers, 1)) + spectrum[None, :]
    return k_cld, jnp.zeros_like(k_cld), jnp.zeros_like(k_cld)
[docs]
def F18_cloud(
    wl: jnp.ndarray,
    r_cm: jnp.ndarray,
    params: Dict[str, jnp.ndarray],
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Fisher & Heng (2018) extinction-efficiency model (optics only).

    This is an "optics engine" that returns (Q_ext, Q_sca, g). For now it
    assumes pure absorption (Q_sca = 0, g = 0). The extinction efficiency is
    ``Q_ext = Q1 / (Q0 * x**(-a) + x**0.2)`` with ``x = 2 pi r / wl``.

    Parameters
    ----------
    wl : `~jax.numpy.ndarray`, shape (nwl,)
        Wavelength grid in microns.
    r_cm : `~jax.numpy.ndarray` or float, shape (nr,) or scalar
        Particle radius in cm. Plain Python scalars are accepted.
    params : dict[str, `~jax.numpy.ndarray`]
        Parameter dictionary containing:

        - `cld_Q0`, `cld_Q1`, `cld_a` (`cld_sigma` is not read here)

    Returns
    -------
    Q_ext : `~jax.numpy.ndarray`, shape (nr, nwl) or (nwl,) if r_cm is scalar
        Extinction efficiency.
    Q_sca : `~jax.numpy.ndarray`, same shape as Q_ext
        Scattering efficiency (zeros).
    g : `~jax.numpy.ndarray`, same shape as Q_ext
        Asymmetry parameter (zeros).
    """
    Q0 = params["cld_Q0"]
    Q1 = params["cld_Q1"]
    a = params["cld_a"]

    # Convert radius to microns to match wl units for size parameter.
    # asarray lets a plain Python float radius through (it has no ``.ndim``).
    r_um = jnp.asarray(r_cm) * 1e4

    # Floor the wavelength so the size parameter never divides by zero.
    wl_safe = jnp.maximum(wl, 1e-30)

    # Broadcast to (nr, nwl) when r_cm is a vector.
    if r_um.ndim == 0:
        x = (2.0 * jnp.pi * r_um) / wl_safe
    else:
        x = (2.0 * jnp.pi * r_um[:, None]) / wl_safe[None, :]

    # Keep x strictly positive so x**(-a) stays finite.
    x = jnp.maximum(x, 1e-30)

    Q_ext = Q1 / (Q0 * x ** (-a) + x**0.2)
    Q_sca = jnp.zeros_like(Q_ext)
    g = jnp.zeros_like(Q_ext)
    return Q_ext, Q_sca, g
[docs]
def direct_nk(
    state: Dict[str, jnp.ndarray],
    params: Dict[str, jnp.ndarray],
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Compute cloud optical properties from retrieved refractive-index nodes.

    This function retrieves node values describing the complex refractive
    index (n, k) as a function of wavelength, interpolates them onto the model
    wavelength grid with ``pchip_1d``, and computes wavelength-dependent
    optical properties using the Rayleigh/MADT blend. The vertical profile is
    provided via q_c_lay in the state dictionary (computed separately by
    vert_cloud kernels). Wavelengths outside the span of the nodes receive no
    cloud contribution.

    Parameters
    ----------
    state : dict[str, `~jax.numpy.ndarray`]
        Atmospheric state dictionary containing:

        - `wl` : `~jax.numpy.ndarray`, shape (nwl,)
            Wavelength grid in microns.
        - `q_c_lay` : `~jax.numpy.ndarray`, shape (nlay,)
            Cloud mass mixing ratio per layer (from vert_cloud kernel).
    params : dict[str, `~jax.numpy.ndarray`]
        Parameter dictionary containing:

        - `wl_node_0`..`wl_node_12` : float
            Wavelength nodes (microns).
        - `n_0`..`n_12` : float
            Real refractive-index nodes.
        - `log_10_k_0`..`log_10_k_12` : float
            Log₁₀ imaginary refractive-index nodes.
        - `log_10_cld_r` : float
            Log₁₀ particle radius in microns.
        - `cld_rho` : float
            Cloud bulk density in g cm⁻³ (required; no default is applied).

    Returns
    -------
    k_cld : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Cloud extinction coefficient in cm² g⁻¹.
    ssa : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Single-scattering albedo derived from (Q_sca / Q_ext), clipped to [0, 1].
    g : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Asymmetry parameter from the Rayleigh/MADT blend (zero in cloud-free
        layers and outside the node wavelength span).
    """
    wl = state["wl"]  # (nwl,) in micron
    q_c_lay = state["q_c_lay"]  # (nlay,)
    q_c_lay = jnp.where(q_c_lay > _QC_EPS, q_c_lay, 0.0)
    has_cloud_any = jnp.any(q_c_lay > 0)

    # -----------------------------------------------------------------------
    # Retrieved / configured knobs
    # -----------------------------------------------------------------------
    r_eff = 10.0 ** params["log_10_cld_r"]  # particle radius (um)
    cld_rho = params["cld_rho"]  # cloud bulk density (g cm^-3)

    # Keep n positive for scattering math sanity (doesn't forbid n<1)
    n_floor = 1e-6
    # Smallest imaginary index allowed after interpolation.
    k_floor = 1e-12
    n_nodes_count = 13

    # -----------------------------
    # Node wavelengths and support
    # -----------------------------
    wl_nodes = jnp.stack([params[f"wl_node_{i}"] for i in range(n_nodes_count)])

    # Limit nk contribution to the wavelength span covered by the nodes
    wl_support_min = jnp.min(wl_nodes)
    wl_support_max = jnp.max(wl_nodes)
    wl_support_mask = jnp.logical_and(wl >= wl_support_min, wl <= wl_support_max)

    # Retrieve n(wl) / k(wl) node values; k is interpolated in log10 space.
    n_nodes = jnp.stack([params[f"n_{i}"] for i in range(n_nodes_count)])
    log10_k_nodes = jnp.stack([params[f"log_10_k_{i}"] for i in range(n_nodes_count)])
    n_interp = pchip_1d(wl, wl_nodes, n_nodes)
    log10_k_interp = pchip_1d(wl, wl_nodes, log10_k_nodes)
    n = jnp.maximum(n_interp, n_floor)
    k = jnp.maximum(10.0 ** log10_k_interp, k_floor)

    # Outside the node span the interpolant is not trusted: pin to the floors.
    n = jnp.where(wl_support_mask, n, n_floor)
    k = jnp.where(wl_support_mask, k, k_floor)

    def _do_cloud(_):
        # Rayleigh/MADT efficiencies per wavelength, zero outside node support.
        Q_ext_vals, Q_sca_vals, g_vals = _compute_mie_madt_efficiencies_masked(
            wl, n, k, r_eff, wl_support_mask
        )

        # Mass extinction: k = 3 q_c Q_ext / (4 rho r), with r converted um -> cm.
        k_cld = (
            (3.0 * q_c_lay[:, None] * Q_ext_vals[None, :])
            / (4.0 * cld_rho * (r_eff * 1e-4))
        )
        ssa_wl = jnp.clip(Q_sca_vals / jnp.maximum(Q_ext_vals, _DIV_EPS), 0.0, 1.0)
        ssa = ssa_wl[None, :] + jnp.zeros_like(q_c_lay[:, None])
        g = g_vals[None, :] + jnp.zeros_like(q_c_lay[:, None])

        # Cloud-free layers carry no optical properties at all.
        q_mask = (q_c_lay > 0)[:, None]
        k_cld = jnp.where(q_mask, k_cld, 0.0)
        ssa = jnp.where(q_mask, ssa, 0.0)
        g = jnp.where(q_mask, g, 0.0)
        return k_cld, ssa, g

    def _skip_cloud(_):
        zeros = jnp.zeros((q_c_lay.shape[0], wl.shape[0]), dtype=wl.dtype)
        return zeros, zeros, zeros

    return jax.lax.cond(has_cloud_any, _do_cloud, _skip_cloud, operand=None)
[docs]
def nk_f18_blend(
    state: Dict[str, jnp.ndarray],
    params: Dict[str, jnp.ndarray],
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Sum the ``direct_nk`` and F18 cloud opacities on a shared cloud profile.

    Both components use the same vertical cloud profile (``q_c_lay`` from
    ``state``). The ``direct_nk`` part reads the standard node parameters
    (``wl_node_*``, ``n_*``, ``log_10_k_*``, ``log_10_cld_r``, ``cld_rho``).
    The F18 part reads its own radius and density keys so the two do not
    collide: ``log_10_cld_r_f18``, ``cld_rho_f18``, together with ``cld_Q0``,
    ``cld_Q1``, ``cld_a``.

    Extinction adds linearly. The combined single-scattering albedo and
    asymmetry parameter are weighted by each component's scattering opacity
    (k_ext × ssa).

    Parameters
    ----------
    state : dict
        Atmospheric state; must contain ``wl``, ``q_c_lay``, ``nlay``, ``nwl``.
    params : dict
        Must contain all ``direct_nk`` params plus the F18-namespaced params
        listed above.

    Returns
    -------
    k_cld : jnp.ndarray, shape (nlay, nwl)
        Combined cloud extinction coefficient in cm² g⁻¹.
    ssa : jnp.ndarray, shape (nlay, nwl)
        Combined single-scattering albedo.
    g : jnp.ndarray, shape (nlay, nwl)
        Combined asymmetry parameter (weighted by scattering opacity).
    """
    # --- Component 1: retrieved-nk cloud ---
    nk_ext, nk_ssa, nk_g = direct_nk(state, params)

    # --- Component 2: F18 cloud ---
    wavelengths = state["wl"]
    q_cloud = state["q_c_lay"]
    q_cloud = jnp.where(q_cloud > _QC_EPS, q_cloud, 0.0)
    radius_um = 10.0 ** params["log_10_cld_r_f18"]
    radius_cm = radius_um * 1e-4
    bulk_density = params["cld_rho_f18"]
    q_ext, q_sca, g_eff = F18_cloud(wavelengths, radius_cm, params)

    # Efficiency -> mass extinction (cm² g⁻¹): k = (3 q_c Q_ext) / (4 rho r)
    kappa_per_q = (3.0 * q_ext) / (4.0 * bulk_density * radius_cm)  # (nwl,)
    f18_ext = q_cloud[:, None] * kappa_per_q[None, :]  # (nlay, nwl)
    albedo_wl = jnp.clip(q_sca / jnp.maximum(q_ext, _DIV_EPS), 0.0, 1.0)
    f18_ssa = albedo_wl[None, :] + jnp.zeros_like(f18_ext)
    f18_g = g_eff[None, :] + jnp.zeros_like(f18_ext)

    # Layers without cloud mass contribute nothing.
    cloudy = (q_cloud > 0)[:, None]
    f18_ext, f18_ssa, f18_g = (
        jnp.where(cloudy, field, 0.0) for field in (f18_ext, f18_ssa, f18_g)
    )

    # --- Combine ---
    k_total = nk_ext + f18_ext

    # Scattering opacities act as the weights for ssa and g.
    nk_sca = nk_ssa * nk_ext
    f18_sca = f18_ssa * f18_ext
    sca_sum = nk_sca + f18_sca

    ssa_total = _safe_div(sca_sum, k_total)
    g_total = _safe_div(
        nk_g * nk_sca + f18_g * f18_sca,
        jnp.maximum(sca_sum, _DIV_EPS),
    )
    return k_total, ssa_total, g_total
[docs]
def f18_skew_cloud(
    state: Dict[str, jnp.ndarray],
    params: Dict[str, jnp.ndarray],
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """F18 continuum plus a skew-normal spectral feature with a Rayleigh size window.

    The extinction efficiency is

        Q_ext(λ) = clip(Q_cont(λ) + W(x) · Q_feat(λ), 0, _Q_EXT_MAX)

    where the continuum follows Fisher & Heng (2018),

        Q_cont = Q1 / (Q0 · x^{-a} + x^{0.2}),   x = 2π r / λ,

    the spectral feature is a skew-normal profile in wavelength,

        Q_feat = 2 A exp(-z²/2) Φ(ξ z),   z = (λ - λ0) / ω,

    with Φ the standard normal CDF, and the size-parameter window

        W(x) = exp(-x / x0)

    damps the feature as the size parameter grows.
    Pure absorption is assumed (Q_sca = 0, g = 0).

    Parameters
    ----------
    state : dict
        Must contain ``wl`` (μm), ``q_c_lay``, ``nlay``, ``nwl``.
    params : dict
        ``log_10_cld_r``  — log₁₀ particle radius (μm)
        ``cld_rho``       — bulk density (g cm⁻³)
        ``cld_Q0``        — continuum opacity scale
        ``cld_Q1``        — continuum opacity scale
        ``cld_a``         — continuum power-law index
        ``cld_amp``       — skew-normal feature amplitude A
        ``cld_lam0``      — feature central wavelength λ0 (μm)
        ``cld_omega``     — feature width ω (μm)
        ``cld_xi``        — skewness parameter ξ (0 → symmetric Gaussian)
        ``cld_x0``        — size-window rolloff scale x0

    Returns
    -------
    k_cld : jnp.ndarray, shape (nlay, nwl)
        Cloud extinction coefficient in cm² g⁻¹.
    ssa : jnp.ndarray, shape (nlay, nwl)
        Single-scattering albedo (zeros — pure absorption).
    g : jnp.ndarray, shape (nlay, nwl)
        Asymmetry parameter (zeros).
    """
    wavelengths = state["wl"]  # (nwl,) μm
    q_cloud = state["q_c_lay"]  # (nlay,)
    q_cloud = jnp.where(q_cloud > _QC_EPS, q_cloud, 0.0)

    radius_um = 10.0 ** params["log_10_cld_r"]
    radius_cm = radius_um * 1e-4
    bulk_density = params["cld_rho"]
    cont_q0 = params["cld_Q0"]
    cont_q1 = params["cld_Q1"]
    cont_index = params["cld_a"]
    feat_amp = params["cld_amp"]
    feat_centre = params["cld_lam0"]
    feat_width = params["cld_omega"]
    feat_skew = params["cld_xi"]
    window_scale = params["cld_x0"]

    # Size parameter on the wavelength grid, (nwl,).
    size_param = (2.0 * jnp.pi * radius_um) / jnp.maximum(wavelengths, _DIV_EPS)

    # Continuum term (F18 form).
    continuum = cont_q1 / (cont_q0 * size_param ** (-cont_index) + size_param ** 0.2)

    # Skew-normal feature: 2 A exp(-z^2/2) Phi(xi z), with Phi written via erf.
    z_score = (wavelengths - feat_centre) / jnp.maximum(feat_width, _DIV_EPS)
    normal_cdf = 0.5 * (1.0 + jax.scipy.special.erf(feat_skew * z_score / jnp.sqrt(2.0)))
    feature = 2.0 * feat_amp * jnp.exp(-0.5 * z_score ** 2) * normal_cdf

    # Exponential window in size parameter that scales the feature.
    window = jnp.exp(-size_param / jnp.maximum(window_scale, _DIV_EPS))

    # Sum and clip into [0, _Q_EXT_MAX].
    q_ext = jnp.clip(continuum + window * feature, 0.0, _Q_EXT_MAX)

    # Efficiency -> mass extinction, then scale by the cloud profile.
    kappa_wl = (3.0 * q_ext) / (4.0 * bulk_density * radius_cm)  # (nwl,)
    k_cld = q_cloud[:, None] * kappa_wl[None, :]  # (nlay, nwl)

    no_scattering = jnp.zeros_like(k_cld)
    return k_cld, no_scattering, no_scattering