"""
vert_Tp.py
==========
"""
from __future__ import annotations
from typing import Dict, Tuple
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from .data_constants import bar
# ---------------- Hopf function ----------------
# Coefficients of the rational fit used by ``hopf_function`` in the
# mid-range 0.01 <= tau <= 5, written in the variable x = log10(tau):
#   numerator   = FIT_P[0]*x**4 + FIT_P[1]*x**3 + FIT_P[2]*x**2 + FIT_P[3]*x + FIT_P[4]
#   denominator = x**4 + FIT_Q[0]*x**3 + FIT_Q[1]*x**2 + FIT_Q[2]*x + FIT_Q[3]  (monic)
FIT_P = jnp.asarray((0.6162, -0.3799, 2.395, -2.041, 2.578))
FIT_Q = jnp.asarray((-0.9799, 3.917, -3.17, 3.69))
# Public names exported by ``from <module> import *``.
# NOTE(review): ``Line`` (the Line et al. 2013 profile defined later in this
# module) has the same ``(p_lev, params) -> (T_lev, T_lay)`` signature as the
# exported profiles but is not listed here -- confirm whether that omission
# is intentional.
__all__ = [
    "hopf_function",
    "isothermal",
    "Barstow",
    "Milne",
    "Modified_Milne",
    "Guillot",
    "Modified_Guillot",
    "MandS",
    "picket_fence",
    "dry_convective_adjustment"
]
[docs]
def hopf_function(tau: jnp.ndarray) -> jnp.ndarray:
    """Compute the Hopf function for radiative transfer.

    Piecewise approximation used by the analytic T-P profiles:

    * ``tau < 0.01``: linear in ``tau`` from 0.577351 (~1/sqrt(3)) at
      ``tau = 0`` to 0.588236 at ``tau = 0.01``.
    * ``0.01 <= tau <= 5``: rational fit in ``x = log10(tau)`` with
      numerator coefficients ``FIT_P`` and a monic denominator with
      coefficients ``FIT_Q``.
    * ``5 < tau <= 1e4``: linear in ``log10(tau)`` from 0.710398 at
      ``tau = 5`` to the asymptotic value 0.710446 at ``tau = 1e4``.
    * ``tau > 1e4``: held at the asymptotic value 0.710446.

    Parameters
    ----------
    tau : `~jax.numpy.ndarray`
        Optical depth (dimensionless). Non-floating input (int/bool) is
        promoted to the default floating dtype. Values below the dtype's
        smallest normal number are treated as that number.

    Returns
    -------
    hopf : `~jax.numpy.ndarray`
        Hopf function value at the given optical depth.
    """
    tau = jnp.asarray(tau)
    # ``jnp.finfo`` is only defined for floating dtypes; promote integer or
    # boolean optical depths so e.g. ``hopf_function(1)`` works.
    if not jnp.issubdtype(tau.dtype, jnp.floating):
        tau = tau.astype(jnp.result_type(float))
    # Clamp to the smallest normal number so log10 stays finite at tau = 0.
    tiny = jnp.finfo(tau.dtype).tiny
    tau_safe = jnp.maximum(tau, tiny)
    x = jnp.log10(tau_safe)

    # Break points and anchor values of the piecewise approximation.
    tau_lo, tau_hi, tau_max = 0.01, 5.0, 10000.0
    q_zero, q_lo = 0.577351, 0.588236
    q_hi, q_inf = 0.710398, 0.710446

    # Mid-range: rational fit in x evaluated via Horner's scheme.
    p0, p1, p2, p3, p4 = FIT_P
    q0, q1, q2, q3 = FIT_Q
    num = (((p0 * x + p1) * x + p2) * x + p3) * x + p4
    den = (((x + q0) * x + q1) * x + q2) * x + q3
    mid = num / den

    # Low-tau patch: linear in tau between tau = 0 and tau_lo.
    low = q_zero + tau_safe * (q_lo - q_zero) / tau_lo

    # High-tau patch: linear in log10(tau) between tau_hi and tau_max, then
    # clamped so that deeper layers never overshoot the asymptote q_inf
    # (unclamped, the line keeps rising beyond tau_max).
    x_hi = jnp.log10(tau_hi)
    x_max = jnp.log10(tau_max)
    high = q_hi + (x - x_hi) * (q_inf - q_hi) / (x_max - x_hi)
    high = jnp.minimum(high, q_inf)

    out = jnp.where(tau_safe < tau_lo, low, mid)
    out = jnp.where(tau_safe > tau_hi, high, out)
    return out
[docs]
def isothermal(p_lev: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Generate an isothermal temperature profile.

    Every level and every layer is set to the same temperature ``T_iso``.

    Parameters
    ----------
    p_lev : `~jax.numpy.ndarray`, shape (nlev,)
        Pressure at atmospheric levels. Only its size is used.
    params : dict[str, `~jax.numpy.ndarray`]
        Parameter dictionary containing:

        - `T_iso` : float
            Isothermal temperature in Kelvin.

    Returns
    -------
    T_lev : `~jax.numpy.ndarray`, shape (nlev,)
        Temperature at levels in Kelvin.
    T_lay : `~jax.numpy.ndarray`, shape (nlev-1,)
        Temperature at layer midpoints in Kelvin.
    """
    temperature = params["T_iso"]
    n_levels = jnp.size(p_lev)
    level_temps = jnp.full((n_levels,), temperature)
    # Built independently (not sliced from level_temps) so the layer array
    # is filled directly with the same value.
    layer_temps = jnp.full((n_levels - 1,), temperature)
    return level_temps, layer_temps
[docs]
def Barstow(p_lev: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Generate a Barstow et al. (2020) temperature profile.

    The profile has three regimes with fixed boundaries at ``0.1 * bar`` and
    ``1.0 * bar`` (``bar`` is the module-level unit constant):

    * ``p <= 0.1 bar``: isothermal at ``T_strat``;
    * ``0.1 bar < p <= 1 bar``: adiabat ``T_strat * (p / 0.1 bar)**(2/7)``;
    * ``p > 1 bar``: isothermal at the adiabat's value at 1 bar.

    Parameters
    ----------
    p_lev : `~jax.numpy.ndarray`, shape (nlev,)
        Pressure at atmospheric levels (same units as ``bar``).
    params : dict[str, `~jax.numpy.ndarray`]
        Parameter dictionary containing:

        - `T_strat` : float
            Upper-atmosphere isothermal temperature in Kelvin.

    Returns
    -------
    T_lev : `~jax.numpy.ndarray`, shape (nlev,)
        Temperature at levels in Kelvin.
    T_lay : `~jax.numpy.ndarray`, shape (nlev-1,)
        Temperature at layer midpoints in Kelvin (arithmetic mean of the
        bounding levels).
    """
    T_strat = params["T_strat"]
    exponent = 2.0 / 7.0
    p_top = 0.1 * bar      # upper edge of the adiabatic region
    p_bottom = 1.0 * bar   # lower edge of the adiabatic region

    # Clamp from below so the adiabat is only ever evaluated at p >= p_top;
    # its value is discarded above p_top anyway.
    T_adiabat = T_strat * (jnp.maximum(p_lev, p_top) / p_top) ** exponent
    T_deep = T_strat * (p_bottom / p_top) ** exponent

    # First matching condition wins, default covers p > p_bottom.
    T_lev = jnp.select(
        [p_lev <= p_top, p_lev <= p_bottom],
        [T_strat, T_adiabat],
        T_deep,
    )
    T_lay = 0.5 * (T_lev[:-1] + T_lev[1:])
    return T_lev, T_lay
[docs]
def Modified_Milne(p_lev: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Generate a modified Milne temperature profile with stretched exponential transition.

    This profile uses a grey optical depth model with a stretched exponential
    transition from a skin temperature at low pressure to the standard Milne
    profile at high pressure::

        T^4  = 3/4 * T_int^4 * (q(p) + tau),      tau = (k_ir / g) * p
        q(p) = q_inf + (q_0 - q_inf) * exp(-(p / p_t)**beta)

    with ``q_inf = 0.710446`` and ``q_0 = (4/3) * T_ratio**4``, so that (for
    ``beta > 0``) the temperature at ``p = 0`` equals ``T_ratio * T_int``.

    Parameters
    ----------
    p_lev : `~jax.numpy.ndarray`, shape (nlev,)
        Pressure at atmospheric levels in dyne cm⁻².
    params : dict[str, `~jax.numpy.ndarray`]
        Parameter dictionary containing:

        - `log_10_g` : float
            Log₁₀ surface gravity in cm s⁻².
        - `T_int` : float
            Internal temperature in Kelvin.
        - `log_10_k_ir` : float
            Log₁₀ infrared opacity in cm² g⁻¹.
        - `T_ratio` : float
            Skin-to-internal temperature ratio (T_skin / T_int, dimensionless).
        - `log_10_p_t` : float
            Log₁₀ transition pressure in bar.
        - `beta` : float
            Stretching exponent for transition (0 < beta <= 1, dimensionless).

    Returns
    -------
    T_lev : `~jax.numpy.ndarray`, shape (nlev,)
        Temperature at levels in Kelvin.
    T_lay : `~jax.numpy.ndarray`, shape (nlev-1,)
        Temperature at layer midpoints in Kelvin.
    """
    g = 10.0**params["log_10_g"]
    T_int = params["T_int"]
    k_ir = 10.0**params["log_10_k_ir"]
    T_ratio = params["T_ratio"]
    # Transition pressure converted from bar to the units of p_lev.
    p_t = (10.0**params["log_10_p_t"]) * bar
    beta = params["beta"]
    tau_ir = k_ir / g * p_lev
    # Deep-atmosphere (tau -> infinity) Hopf value.
    q_inf = 0.710446
    # Chosen so that 3/4 * T_int^4 * q0 == (T_ratio * T_int)^4 at p = 0.
    q0 = (4.0 / 3.0) * (T_ratio**4)
    # Stretched-exponential weight: 1 at p = 0, decaying for p >> p_t.
    sig = jnp.exp(-((p_lev / p_t) ** beta))
    q = q_inf + (q0 - q_inf) * sig
    T_lev = (0.75 * T_int**4 * (q + tau_ir)) ** 0.25
    T_lay = 0.5 * (T_lev[:-1] + T_lev[1:])
    return T_lev, T_lay
[docs]
def Milne(p_lev: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Generate a Milne temperature profile for an internally heated atmosphere.

    Computes ``T^4 = 3/4 * T_int^4 * (q(tau) + tau)``, where ``q`` is the
    Hopf function (see ``hopf_function``) and ``tau = (k_ir / g) * p_lev``.

    Parameters
    ----------
    p_lev : `~jax.numpy.ndarray`, shape (nlev,)
        Pressure at atmospheric levels (presumably dyne cm⁻², so that
        ``k_ir / g * p_lev`` is dimensionless with the cgs units below).
    params : dict[str, `~jax.numpy.ndarray`]
        Parameter dictionary containing:

        - `log_10_g` : float
            Log₁₀ surface gravity in cm s⁻².
        - `T_int` : float
            Internal temperature in Kelvin.
        - `log_10_k_ir` : float
            Log₁₀ infrared opacity in cm² g⁻¹.

    Returns
    -------
    T_lev : `~jax.numpy.ndarray`, shape (nlev,)
        Temperature at levels in Kelvin.
    T_lay : `~jax.numpy.ndarray`, shape (nlev-1,)
        Temperature at layer midpoints in Kelvin.
    """
    # Parameter values are already JAX arrays, no need to wrap
    g = 10.0**params["log_10_g"]
    T_int = params["T_int"]
    k_ir = 10.0**params["log_10_k_ir"]
    tau_ir = k_ir / g * p_lev
    T_lev = (0.75 * T_int**4 * (hopf_function(tau_ir) + tau_ir)) ** 0.25
    T_lay = 0.5 * (T_lev[:-1] + T_lev[1:])
    return T_lev, T_lay
[docs]
def Guillot(p_lev: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Generate a Guillot (2010) analytical temperature profile.

    Combines internal heating and external irradiation in the semi-grey
    two-stream approximation with separate visible and infrared opacities::

        T^4 = 3/4 T_int^4 (2/3 + tau)
            + 3/4 T_eq^4 * 4 f_hem * [2/3 + 1/(gam sqrt3)
                                      + (gam/sqrt3 - 1/(gam sqrt3)) exp(-gam tau sqrt3)]

    with ``tau = (k_ir / g) * p_lev``.

    Parameters
    ----------
    p_lev : `~jax.numpy.ndarray`, shape (nlev,)
        Pressure at atmospheric levels.
    params : dict[str, `~jax.numpy.ndarray`]
        Parameter dictionary containing:

        - `T_int` : float
            Internal temperature in Kelvin.
        - `T_eq` : float
            Equilibrium temperature in Kelvin.
        - `log_10_k_ir` : float
            Log₁₀ infrared opacity in cm² g⁻¹.
        - `log_10_gam_v` : float
            Log₁₀ visible-to-IR opacity ratio (dimensionless).
        - `log_10_g` : float
            Log₁₀ surface gravity in cm s⁻².
        - `f_hem` : float
            Hemispheric redistribution factor (dimensionless).

    Returns
    -------
    T_lev : `~jax.numpy.ndarray`, shape (nlev,)
        Temperature at levels in Kelvin.
    T_lay : `~jax.numpy.ndarray`, shape (nlev-1,)
        Temperature at layer midpoints in Kelvin.
    """
    T_int = params["T_int"]
    T_eq = params["T_eq"]
    f = params["f_hem"]
    k_ir = 10.0**params["log_10_k_ir"]
    gam = 10.0**params["log_10_gam_v"]
    g = 10.0**params["log_10_g"]

    tau = k_ir / g * p_lev
    root3 = jnp.sqrt(3.0)
    inv_gam_root3 = 1.0 / (gam * root3)

    # Internally heated (Eddington) contribution to T^4.
    internal = 0.75 * T_int**4 * (2.0 / 3.0 + tau)

    # Irradiated contribution to T^4.
    attenuation = jnp.exp(-gam * tau * root3)
    irr_shape = 2.0 / 3.0 + inv_gam_root3 + (gam / root3 - inv_gam_root3) * attenuation
    irradiated = 0.75 * T_eq**4 * 4.0 * f * irr_shape

    T_lev = (internal + irradiated) ** 0.25
    T_lay = 0.5 * (T_lev[:-1] + T_lev[1:])
    return T_lev, T_lay
[docs]
def Modified_Guillot(p_lev: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Generate a modified Guillot profile with a flexible irradiated Hopf term.

    Same semi-grey structure as Guillot (2010), except that the fixed 2/3 in
    the irradiation bracket is replaced by::

        q_irr(p) = 2/3 + (q_irr_0 - 2/3) * exp(-(p / p_t)**beta)

    The internal term keeps the Eddington value 2/3. ``T^4`` is floored at
    zero before taking the fourth root.

    Parameters
    ----------
    p_lev : `~jax.numpy.ndarray`, shape (nlev,)
        Pressure at atmospheric levels (same units as ``bar``).
    params : dict[str, `~jax.numpy.ndarray`]
        Parameter dictionary containing:

        - `T_int` : float
            Internal temperature in Kelvin.
        - `T_eq` : float
            Equilibrium temperature in Kelvin.
        - `log_10_k_ir` : float
            Log10 infrared opacity in cm^2 g^-1.
        - `log_10_gam_v` : float
            Log10 visible-to-IR opacity ratio.
        - `log_10_g` : float
            Log10 surface gravity in cm s^-2.
        - `f_hem` : float
            Hemispheric redistribution factor.
        - `q_irr_0` : float
            Irradiated Hopf value at low pressure.
        - `log_10_p_t` : float
            Log10 transition pressure in bar.
        - `beta` : float
            Stretching exponent for the Hopf transition.

    Returns
    -------
    T_lev : `~jax.numpy.ndarray`, shape (nlev,)
        Temperature at levels in Kelvin.
    T_lay : `~jax.numpy.ndarray`, shape (nlev-1,)
        Temperature at layer midpoints in Kelvin.
    """
    T_int = params["T_int"]
    T_eq = params["T_eq"]
    f = params["f_hem"]
    q_irr_0 = params["q_irr_0"]
    beta = params["beta"]
    k_ir = 10.0 ** params["log_10_k_ir"]
    gam = 10.0 ** params["log_10_gam_v"]
    g = 10.0 ** params["log_10_g"]
    p_t = (10.0 ** params["log_10_p_t"]) * bar

    tau = (k_ir / g) * p_lev
    root3 = jnp.sqrt(3.0)
    inv_gam_root3 = 1.0 / (gam * root3)
    eddington = 2.0 / 3.0

    # Stretched-exponential switch: 1 at p = 0 (for beta > 0), -> 0 deep down,
    # so q_irr moves from q_irr_0 towards the Eddington value.
    switch = jnp.exp(-((p_lev / p_t) ** beta))
    q_irr = eddington + (q_irr_0 - eddington) * switch

    internal = 0.75 * T_int**4 * (eddington + tau)
    attenuation = jnp.exp(-gam * tau * root3)
    irr_bracket = q_irr + inv_gam_root3 + (gam / root3 - inv_gam_root3) * attenuation
    irradiated = 0.75 * T_eq**4 * 4.0 * f * irr_bracket

    # Floor at zero: a negative q_irr_0 could otherwise give NaN temperatures.
    T_lev = jnp.maximum(internal + irradiated, 0.0) ** 0.25
    T_lay = 0.5 * (T_lev[:-1] + T_lev[1:])
    return T_lev, T_lay
[docs]
def MandS(p_lev: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Generate a Madhusudhan & Seager (2009) three-region T-P profile.

    With ``L = log10(p_lev / bar)`` and ``L0 = min(L)`` (top of the grid)::

        L <= L1       : T = T_ref + ((L - L0) / a1)**2
        L1 < L <= L3  : T = T2    + ((L - L2) / a2)**2
        L > L3        : T = T3    (isothermal)

    ``T2`` and ``T3`` follow from continuity at ``L1`` and ``L3``. Slopes
    with ``|a| <= 1e-12`` are replaced by ``1e-12`` to avoid dividing by 0.

    Parameters
    ----------
    p_lev : `~jax.numpy.ndarray`, shape (nlev,)
        Pressure at atmospheric levels (same units as ``bar``).
    params : dict[str, `~jax.numpy.ndarray`]
        Parameter dictionary containing:

        - `a1`, `a2` : float
            Shape/slope parameters controlling the inversion strength.
        - `log_10_P1`, `log_10_P2`, `log_10_P3` : float
            Transition pressures in log₁₀(bar). Profile is computed in log₁₀
            space throughout, matching the Madhusudhan & Seager (2009) convention.
        - `T_ref` : float
            Reference temperature at the top of the atmosphere in Kelvin.

    Returns
    -------
    T_lev : `~jax.numpy.ndarray`, shape (nlev,)
        Temperature at levels in Kelvin.
    T_lay : `~jax.numpy.ndarray`, shape (nlev-1,)
        Temperature at layer midpoints in Kelvin.
    """
    p_lev = jnp.asarray(p_lev)
    T_ref = params["T_ref"]
    # Transition pressures are already log10(bar): no unit conversion needed.
    log_P1 = params["log_10_P1"]
    log_P2 = params["log_10_P2"]
    log_P3 = params["log_10_P3"]

    def _floor_slope(a):
        # Keep the divisor away from zero (sign is irrelevant after squaring).
        return jnp.where(jnp.abs(a) > 1e-12, a, 1e-12)

    def _bump(lp, lp_ref, a):
        # Squared log-pressure offset scaled by the slope parameter.
        return ((lp - lp_ref) / a) ** 2

    a1 = _floor_slope(params["a1"])
    a2 = _floor_slope(params["a2"])

    log_P = jnp.log10(p_lev / bar)
    log_P0 = jnp.min(log_P)  # top of the atmosphere on this grid

    # Region anchors enforcing continuity at L1 and L3.
    T1 = T_ref + _bump(log_P1, log_P0, a1)
    T2 = T1 - _bump(log_P1, log_P2, a2)
    T3 = T2 + _bump(log_P3, log_P2, a2)

    # First true condition wins; anything deeper than L3 (or NaN) gets T3.
    T_lev = jnp.select(
        [log_P <= log_P1, log_P <= log_P3],
        [T_ref + _bump(log_P, log_P0, a1), T2 + _bump(log_P, log_P2, a2)],
        T3,
    )
    T_lay = 0.5 * (T_lev[:-1] + T_lev[1:])
    return T_lev, T_lay
def Line(p_lev: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Generate a Line et al. (2013) T-P profile with two visible channels.

    Following Eq. (13) of Line et al. (2013)::

        T^4 = 3/4 T_int^4 (2/3 + tau)
            + 3/4 T_irr^4 [(1 - alpha) xi(gam_1) + alpha xi(gam_2)]

    with ``tau = (k_ir / g) * p_lev``, ``T_irr^4 = 4 f_hem T_eq^4`` and::

        xi(gam) = 2/3 + 2/(3 gam) [1 + (gam tau / 2 - 1) exp(-gam tau)]
                + 2 gam / 3 (1 - tau^2 / 2) E_2(gam tau)

    where ``E_2`` is the exponential integral of order 2.

    Parameters
    ----------
    p_lev : `~jax.numpy.ndarray`, shape (nlev,)
        Pressure at atmospheric levels.
    params : dict[str, `~jax.numpy.ndarray`]
        Parameter dictionary containing (all scalars):

        - `T_int` : float
            Internal temperature in Kelvin.
        - `T_eq` : float
            Equilibrium temperature in Kelvin.
        - `f_hem` : float
            Redistribution factor.
        - `log_10_k_ir` : float
            Log₁₀ thermal/IR opacity in cm² g⁻¹.
        - `log_10_g` : float
            Log₁₀ gravity in cm s⁻².
        - `log_10_gam_v1`, `log_10_gam_v2` : float
            Log₁₀ visible-to-thermal opacity ratios of channels 1 and 2.
        - `alpha` : float
            Partition between the visible channels, in [0, 1].

    Returns
    -------
    T_lev : `~jax.numpy.ndarray`, shape (nlev,)
        Temperature at levels in Kelvin.
    T_lay : `~jax.numpy.ndarray`, shape (nlev-1,)
        Temperature at layer midpoints in Kelvin.
    """
    T_int = params["T_int"]
    T_eq = params["T_eq"]
    f = params["f_hem"]
    alpha = params["alpha"]
    k_ir = 10.0 ** params["log_10_k_ir"]
    g = 10.0 ** params["log_10_g"]
    gam1 = 10.0 ** params["log_10_gam_v1"]
    gam2 = 10.0 ** params["log_10_gam_v2"]

    tau = (k_ir / g) * p_lev
    T_irr4 = (4.0 * f) * (T_eq**4)

    def xi(gamma):
        # Irradiation shape function for one visible channel at every tau.
        arg = gamma * tau
        e2 = jsp.special.expn(2, arg)  # exponential integral E_2(arg)
        flat = 2.0 / 3.0
        exp_part = (2.0 / (3.0 * gamma)) * (1.0 + (arg / 2.0 - 1.0) * jnp.exp(-arg))
        e2_part = (2.0 * gamma / 3.0) * (1.0 - (tau**2) / 2.0) * e2
        return flat + exp_part + e2_part

    internal = (3.0 * T_int**4 / 4.0) * (2.0 / 3.0 + tau)
    irr_scale = 3.0 * T_irr4 / 4.0
    T4 = (
        internal
        + irr_scale * (1.0 - alpha) * xi(gam1)
        + irr_scale * alpha * xi(gam2)
    )
    T_lev = T4 ** 0.25
    T_lay = 0.5 * (T_lev[:-1] + T_lev[1:])
    return T_lev, T_lay
[docs]
def picket_fence(p_lev: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Generate a Parmentier & Guillot (2014,2015) picket fence T-P profile.

    This profile uses a picket fence approximation for radiative transfer:
    the thermal opacity is split into two bins (characterised below by
    ``gam_1`` and ``gam_2``, with ``gam_1 / gam_2 = R``), plus one visible
    channel. The level temperatures are ``T = (T^4)**0.25`` with::

        T^4 = 3/4 T_int^4 (tau + A + B exp(-tau / tau_lim))
            + 3/4 T_eq^4 * 4 f_hem * (C + D exp(-tau / tau_lim) + E exp(-gv tau))

    where ``tau = 10**log_10_k_ir / 10**log_10_g * p_lev`` and
    ``gv = sqrt(3) * 10**log_10_gam_v``.

    Parameters
    ----------
    p_lev : `~jax.numpy.ndarray`, shape (nlev,)
        Pressure at atmospheric levels.
    params : dict[str, `~jax.numpy.ndarray`]
        Parameter dictionary containing:

        - `T_int`, `T_eq` : float
            Internal and equilibrium temperatures in Kelvin.
        - `log_10_k_ir`, `log_10_gam_v` : float
            Log₁₀ infrared opacity (cm² g⁻¹) and log₁₀ visible-to-IR ratio.
        - `log_10_R`, `Beta` : float
            Picket-fence parameters (dimensionless); ``R = 10**log_10_R``
            and ``Beta`` is used directly.
        - `log_10_g` : float
            Log₁₀ surface gravity in cm s⁻².
        - `f_hem` : float
            Hemispheric redistribution factor (dimensionless).

    Returns
    -------
    T_lev : `~jax.numpy.ndarray`, shape (nlev,)
        Temperature at levels in Kelvin.
    T_lay : `~jax.numpy.ndarray`, shape (nlev-1,)
        Temperature at layer midpoints in Kelvin.

    Notes
    -----
    No guarding is done against singular parameter combinations: the
    coefficients divide by ``1 - (gv * tau_lim)**2``, ``1 - gam_1``,
    ``1 - gam_2``, ``gam_1 - gam_2`` and ``1 - gam_p``. For example
    ``R == 1`` gives ``gam_1 == gam_2 == gam_p == 1`` and hence inf/NaN.
    A negative ``T^4`` also produces NaN temperatures.
    """
    # Parameter values are already JAX arrays, no need to wrap
    T_int = params["T_int"]
    T_eq = params["T_eq"]
    k_ir = 10.0**params["log_10_k_ir"]
    gam_v = 10.0**params["log_10_gam_v"]
    R = 10.0**params["log_10_R"]
    B = params["Beta"]
    g = 10.0**params["log_10_g"]
    f = params["f_hem"]
    # Grey thermal optical depth at each level.
    tau_ir = k_ir / g * p_lev
    # Fixed angle cosine mu = 1/sqrt(3), so gv = gam_v / mu = sqrt(3) * gam_v.
    mu = 1.0/jnp.sqrt(3.0)
    gv = gam_v / mu
    # gam_1 = s and gam_2 = s / R characterise the two thermal bins
    # (gam_1 / gam_2 = R). gam_p is presumably the Planck-to-Rosseland mean
    # opacity ratio -- TODO confirm against Parmentier & Guillot (2014).
    s = B + R - B * R
    gam_p = s + s / R - (s * s) / R
    gam_1 = s
    gam_2 = s / R
    # Algebraically equal to sqrt(gam_p / 3) / (gam_1 * gam_2).
    tau_lim = (jnp.sqrt(R) * jnp.sqrt(B * (1.0 - B) * (R - 1.0) ** 2 + R)) / (jnp.sqrt(3.0) * s ** 2)
    # Logarithmic helper terms for the thermal (At*) and visible (Av*) parts.
    At1 = gam_1**2 * jnp.log(1.0 + 1.0 / (tau_lim * gam_1))
    At2 = gam_2**2 * jnp.log(1.0 + 1.0 / (tau_lim * gam_2))
    Av1 = gam_1**2 * jnp.log(1.0 + gv / gam_1)
    Av2 = gam_2**2 * jnp.log(1.0 + gv / gam_2)
    # ---------- auxiliary coefficients a0..a3, b0..b3 feeding A..E ----------
    a0 = 1.0 / gam_1 + 1.0 / gam_2
    a1 = -(1.0 / (3.0 * tau_lim**2)) * (
        (gam_p / (1.0 - gam_p)) * ((gam_1 + gam_2 - 2.0) / (gam_1 + gam_2))
        + (gam_1 + gam_2) * tau_lim
        - (At1 + At2) * tau_lim**2
    )
    # Shared visible-channel denominator; singular when gv * tau_lim == 1.
    den_v = (1.0 - (gv**2) * (tau_lim**2))
    num_a2 = (
        (3.0 * gam_1**2 - gv**2) * (3.0 * gam_2**2 - gv**2) * (gam_1 + gam_2)
        - 3.0 * gv * (6.0 * gam_1**2 * gam_2**2 - gv**2 * (gam_1**2 + gam_2**2))
    )
    a2 = (tau_lim**2 / (gam_p * gv**2)) * (num_a2 / den_v)
    a3 = -(
        tau_lim**2
        * (3.0 * gam_1**2 - gv**2)
        * (3.0 * gam_2**2 - gv**2)
        * (Av2 + Av1)
    ) / (gam_p * gv**3 * den_v)
    term_b0 = (
        (gam_1 * gam_2 / (gam_1 - gam_2)) * (At1 - At2) / 3.0
        - (gam_1 * gam_2) ** 2 / jnp.sqrt(3.0 * gam_p)
        - (gam_1 * gam_2) ** 3 / ((1.0 - gam_1) * (1.0 - gam_2) * (gam_1 + gam_2))
    )
    b0 = 1.0 / term_b0
    b1 = (
        gam_1 * gam_2
        * (3.0 * gam_1**2 - gv**2)
        * (3.0 * gam_2**2 - gv**2)
        * tau_lim**2
    ) / (gam_p * gv**2 * (gv**2 * tau_lim**2 - 1.0))
    b2 = (3.0 * (gam_1 + gam_2) * gv**3) / (
        (3.0 * gam_1**2 - gv**2) * (3.0 * gam_2**2 - gv**2)
    )
    b3 = (Av2 - Av1) / (gv * (gam_1 - gam_2))
    # ---------- A..E (eqs 77-81) ----------
    A_pf = (a0 + a1 * b0) / 3.0
    B_pf = -(1.0 / 3.0) * ((gam_1 * gam_2) ** 2 / gam_p) * b0
    C_pf = -(1.0 / 3.0) * (b0 * b1 * (1.0 + b2 + b3) * a1 + a2 + a3)
    D_pf = (1.0 / 3.0) * ((gam_1 * gam_2) ** 2 / gam_p) * b0 * b1 * (1.0 + b2 + b3)
    E_pf = (
        (3.0 - (gv / gam_1) ** 2) * (3.0 - (gv / gam_2) ** 2)
    ) / (9.0 * gv * ((gv * tau_lim) ** 2 - 1.0))
    # ---------- Temperature profile (eq 76): returns T^4 ----------
    # NOTE: T_lev temporarily holds T^4 here; it is converted to T just below.
    T_lev = (3.0 * T_int**4 / 4.0) * (tau_ir + A_pf + B_pf * jnp.exp(-tau_ir / tau_lim)) \
        + (3.0 * T_eq**4 / 4.0) * 4.0*f * (C_pf + D_pf * jnp.exp(-tau_ir / tau_lim) + E_pf * jnp.exp(-gv * tau_ir))
    T_lev = T_lev**0.25
    T_lay = 0.5 * (T_lev[:-1] + T_lev[1:])
    return T_lev, T_lay
[docs]
def dry_convective_adjustment(T_lay: jnp.ndarray, p_lay: jnp.ndarray, p_lev: jnp.ndarray, kappa: float, max_iter: int = 10, tol: float = 1e-6) -> jnp.ndarray:
    """Apply dry convective adjustment to a temperature profile.

    Adjacent layer pairs ``(k, k+1)`` are flagged unstable when
    ``T[k] < T[k+1] * (p_lay[k] / p_lay[k+1])**kappa - tol``. An unstable
    pair is replaced by the neutral pair that preserves the pressure-thickness
    weighted sum ``dp[k] * T[k] + dp[k+1] * T[k+1]``, where
    ``dp = p_lev[1:] - p_lev[:-1]``.

    Each iteration sweeps the pairs top-to-bottom, then bottom-to-top.
    Exactly ``max_iter`` iterations run (``lax.scan`` has no early exit).

    Parameters
    ----------
    T_lay : `~jax.numpy.ndarray`, shape (nlay,)
        Initial layer temperatures in Kelvin.
    p_lay : `~jax.numpy.ndarray`, shape (nlay,)
        Layer pressures.
    p_lev : `~jax.numpy.ndarray`, shape (nlay+1,)
        Level pressures.
    kappa : float
        Adiabatic index (R/cp).
    max_iter : int, optional
        Number of adjustment iterations; must be a static Python int.
    tol : float, optional
        Tolerance for the stability check.

    Returns
    -------
    T_lay_adj : `~jax.numpy.ndarray`, shape (nlay,)
        Convectively adjusted layer temperature profile in Kelvin.
    """
    n_pairs = T_lay.shape[0] - 1
    thickness = p_lev[1:] - p_lev[:-1]

    def relax_pair(temps, upper):
        # Neutralise the pair (upper, upper + 1) if it is unstable.
        lower = upper + 1
        ratio = (p_lay[upper] / p_lay[lower]) ** kappa
        unstable = temps[upper] < (temps[lower] * ratio - tol)
        dp_up = thickness[upper]
        dp_lo = thickness[lower]
        t_mean = (dp_up * temps[upper] + dp_lo * temps[lower]) / (dp_up + dp_lo)
        t_lo_new = (dp_up + dp_lo) * t_mean / (dp_lo + ratio * dp_up)
        t_up_new = t_lo_new * ratio
        adjusted = temps.at[upper].set(t_up_new).at[lower].set(t_lo_new)
        return jnp.where(unstable, adjusted, temps)

    def top_down(k, temps):
        return relax_pair(temps, k)

    def bottom_up(k, temps):
        return relax_pair(temps, n_pairs - 1 - k)

    def sweep(temps, _):
        temps = jax.lax.fori_loop(0, n_pairs, top_down, temps)
        temps = jax.lax.fori_loop(0, n_pairs, bottom_up, temps)
        return temps, None

    # scan (rather than while_loop) keeps the adjustment reverse-differentiable.
    result, _ = jax.lax.scan(sweep, T_lay, None, length=max_iter)
    return result