"""
rate_jax.py
===========
"""
from __future__ import annotations
from typing import Dict, Mapping, Tuple, Union, Optional
import os
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
import optimistix as optx
# Gas constant [J/(mol·K)]
R_GAS = 8.314462618
# ============================================================================
# Global cache for NASA-9 thermo data
# ============================================================================
# Module-level singleton: assigned by load_nasa9_cache(), reset to None by
# clear_nasa9_cache(), and read by get_nasa9_cache() / is_nasa9_cache_loaded().
_NASA9_CACHE: Optional["NASA9ThermoJAX"] = None
# Type alias: a tuple of species-name strings (used to annotate RateJAX.species).
Species = Tuple[str, ...]
# Public names exported by a star-import of this module.
__all__ = [
"NASA9ThermoJAX",
"load_nasa9_cache",
"get_nasa9_cache",
"clear_nasa9_cache",
"is_nasa9_cache_loaded",
"RateJAX",
]
class NASA9ThermoJAX:
    """NASA-9 polynomial thermodynamics evaluated directly in JAX.

    For every species the object keeps a low- and a high-temperature set of
    ten NASA-9 coefficients plus the switch temperature and the clipping
    limits, and evaluates the dimensionless Gibbs free energy `G/(R T)`
    on demand (nothing is pre-tabulated).
    """

    def __init__(self, data: Mapping[str, Mapping[str, jnp.ndarray]]):
        # Per-species mapping; g_over_RT reads the keys "coeffs_low",
        # "coeffs_high", "t_switch", "t_min" and "t_max".
        self.data = data

    def g_over_RT(self, spec: str, T: jnp.ndarray) -> jnp.ndarray:
        """Return `G/(R T) = H/(R T) - S/R` for one species.

        Parameters
        ----------
        spec : str
            Species key in the thermo table (e.g., `"H2O"`).
        T : `~jax.numpy.ndarray`
            Temperature in Kelvin. Values are clipped to the species'
            `[t_min, t_max]` interval before evaluation.

        Returns
        -------
        g_over_RT : `~jax.numpy.ndarray`
            Dimensionless Gibbs free energy at the (clipped) temperature,
            taken from the low-temperature fit below `t_switch` and from the
            high-temperature fit otherwise.
        """
        entry = self.data[spec]
        temp = jnp.clip(jnp.asarray(T), entry["t_min"], entry["t_max"])

        def _gibbs(coeffs: jnp.ndarray, t: jnp.ndarray) -> jnp.ndarray:
            # Ten slots per range: a1..a8 are polynomial terms, a9 is the
            # enthalpy integration constant, a10 the entropy constant.
            a1, a2, a3, a4, a5, a6, a7, a8, a9, a10 = coeffs
            log_t = jnp.log(t)
            t2 = t * t
            t3 = t2 * t
            t4 = t3 * t
            t5 = t4 * t
            enthalpy = (
                -a1 / t2
                + (a2 * log_t) / t
                + a3
                + a4 * t / 2.0
                + a5 * t2 / 3.0
                + a6 * t3 / 4.0
                + a7 * t4 / 5.0
                + a8 * t5 / 6.0
                + a9 / t
            )
            entropy = (
                -0.5 * a1 / t2
                - a2 / t
                + a3 * log_t
                + a4 * t
                + a5 * t2 / 2.0
                + a6 * t3 / 3.0
                + a7 * t4 / 4.0
                + a8 * t5 / 5.0
                + a10
            )
            return enthalpy - entropy

        # Both fits are evaluated; the mask picks one per temperature element.
        return jnp.where(
            temp < entry["t_switch"],
            _gibbs(entry["coeffs_low"], temp),
            _gibbs(entry["coeffs_high"], temp),
        )
# ============================================================================
# Global cache management functions
# ============================================================================
def load_nasa9_cache(nasa9_dir: str) -> NASA9ThermoJAX:
    """Read every ``*.txt`` NASA-9 coefficient file in a directory and cache the result.

    Intended to be called once at start-up, before forward models or
    retrievals run; the resulting table is stored in the module-level cache
    and also returned.

    Parameters
    ----------
    nasa9_dir : str
        Directory containing NASA-9 coefficient files. The species name is
        the file name without its ``.txt`` extension.

    Returns
    -------
    nasa9 : NASA9ThermoJAX
        The loaded NASA-9 thermo table (also cached globally).

    Raises
    ------
    ValueError
        If a file does not reduce to exactly 20 coefficients.
    """
    global _NASA9_CACHE

    # Defaults applied to every species; only the switch temperature can be
    # overridden by a file (23-number format).
    default_t_min = 200.0
    default_t_max = 6000.0
    default_t_switch = 1000.0

    def _parse_coeff_file(path: str) -> Tuple[np.ndarray, np.ndarray, float]:
        """Return (low10, high10, t_switch) parsed from one coefficient file.

        Accepted layouts: 20 numbers (low10 + high10), 23 numbers
        (Tmin, Tmid, Tmax + low10 + high10), or 30 numbers (three ranges, of
        which only the first two are kept). Text after '#' or '!' on a line
        is treated as a comment.
        """
        with open(path, "r", encoding="utf-8") as handle:
            text_lines = handle.read().splitlines()
        # Drop trailing comments, then convert Fortran 'D' exponents to 'E'.
        stripped = [
            entry.split("#", 1)[0].split("!", 1)[0] for entry in text_lines
        ]
        values = np.fromstring("\n".join(stripped).replace("D", "E"), sep=" ")

        switch = default_t_switch
        if values.size == 23:
            # Header is (Tmin, Tmid, Tmax); only Tmid is used.
            switch = float(values[1])
            values = values[3:]
        if values.size == 30:
            # Third (extra high-temperature) range is ignored.
            values = values[:20]
        if values.size != 20:
            raise ValueError(
                f"Expected 20 NASA-9 coefficients in {path} (or 23 including Tmin/Tmid/Tmax), got {values.size}."
            )
        return values[:10], values[10:], switch

    table: Dict[str, Dict[str, jnp.ndarray]] = {}
    txt_names = (name for name in os.listdir(nasa9_dir) if name.endswith(".txt"))
    for fname in txt_names:
        low, high, switch = _parse_coeff_file(os.path.join(nasa9_dir, fname))
        table[os.path.splitext(fname)[0]] = {
            "coeffs_low": jnp.asarray(low),
            "coeffs_high": jnp.asarray(high),
            "t_switch": jnp.asarray(switch),
            "t_min": jnp.asarray(default_t_min),
            "t_max": jnp.asarray(default_t_max),
        }

    _NASA9_CACHE = NASA9ThermoJAX(table)
    return _NASA9_CACHE
def get_nasa9_cache() -> NASA9ThermoJAX:
    """Fetch the NASA-9 thermo table stored in the module-level cache.

    Returns
    -------
    nasa9 : NASA9ThermoJAX
        The cached NASA-9 thermo table.

    Raises
    ------
    RuntimeError
        If the cache is still empty, i.e. load_nasa9_cache() has not been
        called (or the cache was cleared since).
    """
    cached = _NASA9_CACHE
    if cached is not None:
        return cached
    raise RuntimeError(
        "NASA-9 cache not initialized. Call load_nasa9_cache() first."
    )
def clear_nasa9_cache() -> None:
    """Drop the cached NASA-9 thermo table.

    After this call the module-level cache is empty again, so a fresh
    load_nasa9_cache() is needed before get_nasa9_cache() can succeed.
    Handy for releasing memory or for reloading from a different directory.
    """
    global _NASA9_CACHE
    _NASA9_CACHE = None
def is_nasa9_cache_loaded() -> bool:
    """Report whether the module-level NASA-9 cache currently holds a table.

    Returns
    -------
    loaded : bool
        True if a table is cached, False if the cache is empty.
    """
    cache_is_empty = _NASA9_CACHE is None
    return not cache_is_empty
class RateJAX:
    """RATE-style thermochemical equilibrium solver implemented in JAX.

    This class computes equilibrium abundances for a reduced H/C/N/O chemistry
    network over a 1D (T, p) profile. It is designed to be usable inside
    JIT-compiled forward models:

    - Uses `jax.vmap` to solve each layer independently
    - Avoids SciPy; uses `optimistix` for root finding where needed
    - Returns a VMR dictionary keyed by species name

    Parameters
    ----------
    thermo : `~exo_skryer.rate_jax.NASA9ThermoJAX`
        NASA-9 thermo evaluator created by `load_nasa9_cache`.
    C, N, O : float
        Elemental abundances (number ratios relative to H₂, following the original
        RATE conventions).
    fHe : float
        Helium fraction factor used to compute He from H-bearing species.
    """

    def __init__(
        self,
        thermo: NASA9ThermoJAX,
        C: float = 2.5e-4,
        N: float = 1.0e-4,
        O: float = 5.0e-4,
        fHe: float = 0.0,
    ):
        self.thermo = thermo
        # Stored exactly as passed (no float() conversion), so traced JAX
        # values supplied by a jitted caller are not forced to concrete floats.
        self.C = C
        self.N = N
        self.O = O
        self.fHe = fHe
        # Output ordering: the first nine entries match solve_rest(), followed
        # by the hydrogen/helium species appended in solve_profile().
        self.species: Species = (
            "H2O", "CH4", "CO", "CO2", "NH3",
            "C2H2", "C2H4", "HCN", "N2",
            "H2", "H", "He",
        )

    # ---------- Thermo wrappers ----------
    def g_over_RT(self, spec: str, T: jnp.ndarray) -> jnp.ndarray:
        """Return `G/(R T)` of species `spec` at temperature `T` via `self.thermo`."""
        return self.thermo.g_over_RT(spec, T)

    # ---------- Equilibrium constants k' ----------
    def kprime0(self, T: jnp.ndarray, p: jnp.ndarray) -> jnp.ndarray:
        """
        Equilibrium constant for hydrogen dissociation: H2 ↔ 2H

        K'₀ = exp(-ΔG/RT) / p
        where ΔG = 2·G(H) - G(H₂)

        Parameters
        ----------
        T : array
            Temperature [K]
        p : array
            Pressure [bar]

        Returns
        -------
        K'₀ : array
            Modified equilibrium constant [bar⁻¹]
        """
        return jnp.exp(-(2.0 * self.g_over_RT("H", T) - self.g_over_RT("H2", T))) / p

    def kprime1(self, T: jnp.ndarray, p: jnp.ndarray) -> jnp.ndarray:
        """
        Equilibrium constant for methane-water reaction: CH₄ + H₂O ↔ CO + 3H₂

        K'₁ = exp(-ΔG/RT) / p²
        where ΔG = G(CO) + 3·G(H₂) - G(CH₄) - G(H₂O)

        This is the key reaction controlling the C/O ratio in hot atmospheres.

        Parameters
        ----------
        T : array
            Temperature [K]
        p : array
            Pressure [bar]

        Returns
        -------
        K'₁ : array
            Modified equilibrium constant [bar⁻²]
        """
        return jnp.exp(
            -(
                self.g_over_RT("CO", T) + 3.0 * self.g_over_RT("H2", T)
                - self.g_over_RT("CH4", T) - self.g_over_RT("H2O", T)
            )
        ) / p**2

    def kprime2(self, T: jnp.ndarray) -> jnp.ndarray:
        """
        Equilibrium constant for carbon dioxide reduction: CO₂ + H₂ ↔ CO + H₂O

        K'₂ = exp(-ΔG/RT)
        where ΔG = G(CO) + G(H₂O) - G(CO₂) - G(H₂)

        Parameters
        ----------
        T : array
            Temperature [K]

        Returns
        -------
        K'₂ : array
            Equilibrium constant [dimensionless]
        """
        return jnp.exp(
            -(
                self.g_over_RT("CO", T) + self.g_over_RT("H2O", T)
                - self.g_over_RT("CO2", T) - self.g_over_RT("H2", T)
            )
        )

    def kprime3(self, T: jnp.ndarray, p: jnp.ndarray) -> jnp.ndarray:
        """
        Equilibrium constant for acetylene formation: 2CH₄ ↔ C₂H₂ + 3H₂

        K'₃ = exp(-ΔG/RT) / p²
        where ΔG = G(C₂H₂) + 3·G(H₂) - 2·G(CH₄)

        Important for high-C/O and high-temperature atmospheres.

        Parameters
        ----------
        T : array
            Temperature [K]
        p : array
            Pressure [bar]

        Returns
        -------
        K'₃ : array
            Modified equilibrium constant [bar⁻²]
        """
        return jnp.exp(
            -(
                self.g_over_RT("C2H2", T) + 3.0 * self.g_over_RT("H2", T)
                - 2.0 * self.g_over_RT("CH4", T)
            )
        ) / p**2

    def kprime4(self, T: jnp.ndarray, p: jnp.ndarray) -> jnp.ndarray:
        """
        Equilibrium constant for ethylene-acetylene: C₂H₄ ↔ C₂H₂ + H₂

        K'₄ = exp(-ΔG/RT) / p
        where ΔG = G(C₂H₂) + G(H₂) - G(C₂H₄)

        Parameters
        ----------
        T : array
            Temperature [K]
        p : array
            Pressure [bar]

        Returns
        -------
        K'₄ : array
            Modified equilibrium constant [bar⁻¹]
        """
        return jnp.exp(
            -(
                self.g_over_RT("C2H2", T) + self.g_over_RT("H2", T)
                - self.g_over_RT("C2H4", T)
            )
        ) / p

    def kprime5(self, T: jnp.ndarray, p: jnp.ndarray) -> jnp.ndarray:
        """
        Equilibrium constant for ammonia dissociation: 2NH₃ ↔ N₂ + 3H₂

        K'₅ = exp(-ΔG/RT) / p²
        where ΔG = G(N₂) + 3·G(H₂) - 2·G(NH₃)

        Dominant nitrogen chemistry reaction in hot atmospheres.

        Parameters
        ----------
        T : array
            Temperature [K]
        p : array
            Pressure [bar]

        Returns
        -------
        K'₅ : array
            Modified equilibrium constant [bar⁻²]
        """
        return jnp.exp(
            -(
                self.g_over_RT("N2", T) + 3.0 * self.g_over_RT("H2", T)
                - 2.0 * self.g_over_RT("NH3", T)
            )
        ) / p**2

    def kprime6(self, T: jnp.ndarray, p: jnp.ndarray) -> jnp.ndarray:
        """
        Equilibrium constant for HCN formation: NH₃ + CH₄ ↔ HCN + 3H₂

        K'₆ = exp(-ΔG/RT) / p²
        where ΔG = G(HCN) + 3·G(H₂) - G(NH₃) - G(CH₄)

        Important when both N and C are abundant at high temperatures.

        Parameters
        ----------
        T : array
            Temperature [K]
        p : array
            Pressure [bar]

        Returns
        -------
        K'₆ : array
            Modified equilibrium constant [bar⁻²]
        """
        return jnp.exp(
            -(
                self.g_over_RT("HCN", T) + 3.0 * self.g_over_RT("H2", T)
                - self.g_over_RT("NH3", T) - self.g_over_RT("CH4", T)
            )
        ) / p**2

    # ---------- Turnover pressure (CO vs H2O dominated) ----------
    @staticmethod
    def top(T: jnp.ndarray, C: float, N: float, O: float) -> jnp.ndarray:
        """
        Turnover pressure: transition between CO-dominated and H2O-dominated chemistry.

        Computes the pressure where CO and H2O abundances become comparable,
        based on a polynomial fit to thermochemical equilibrium calculations
        (Lodders & Fegley 2002).

        Parameters
        ----------
        T : array
            Temperature [K]
        C : float
            Carbon elemental abundance (number ratio relative to H2)
        N : float
            Nitrogen elemental abundance (number ratio relative to H2)
        O : float
            Oxygen elemental abundance (number ratio relative to H2)

        Returns
        -------
        P_turnover : array
            Turnover pressure [bar], where CO/H2O ~ 1
        """
        # Polynomial coefficients organized by variable
        # Structure: constant + T^1..4 + C^1..4 + N^1..4 + O^1..4
        const = -1.07028658e+03
        coeff_T = jnp.array([1.20815018e+03, -5.21868655e+02, 1.02459233e+02, -7.68350388e+00])
        coeff_C = jnp.array([1.30787500e+00, 3.18619604e-01, 5.32918135e-02, 3.12269845e-03])
        coeff_N = jnp.array([2.81238906e-02, 1.26015039e-02, 2.07616221e-03, 1.16038224e-04])
        coeff_O = jnp.array([-1.69589064e-01, -5.21662503e-02, -7.33669631e-03, -3.74492912e-04])

        # The fit is expressed in log10 of every input variable.
        logT = jnp.log10(T)
        logC = jnp.log10(C)
        logN = jnp.log10(N)
        logO = jnp.log10(O)

        # Powers array [1, 2, 3, 4] for vectorized exponentiation
        powers = jnp.array([1.0, 2.0, 3.0, 4.0])

        # logT may be batched, so a trailing axis is added for the powers and
        # summed away again; C, N and O contribute scalars.
        T_contrib = jnp.sum(coeff_T * logT[..., None] ** powers, axis=-1)
        C_contrib = jnp.sum(coeff_C * logC ** powers)
        N_contrib = jnp.sum(coeff_N * logN ** powers)
        O_contrib = jnp.sum(coeff_O * logO ** powers)
        log10_P_turn = const + T_contrib + C_contrib + N_contrib + O_contrib

        # Clip to valid pressure range: 10^-8 to 10^3 bar
        log10_P_turn = jnp.clip(log10_P_turn, -8.0001, 3.0001)
        return 10.0 ** log10_P_turn

    # ---------- Polynomial builders ----------
    def HCO_poly6_CO(self, f, k1, k2, k3, k4):
        """
        HCO chemistry, polynomial in CO.

        Returns 7 coefficients ordered from the constant term upward; the
        last one is 0.0 so the array has the same length as the HCNO
        polynomials (required to stack them in `_solve_one_layer`).
        """
        C, O = self.C, self.O
        A0 = -C * O**2 * f**3 * k1**2 * k2**3 * k4
        A1 = (
            -C * O**2 * f**3 * k1**2 * k2**2 * k4
            + 2 * C * O * f**2 * k1**2 * k2**3 * k4
            + O**3 * f**3 * k1**2 * k2**2 * k4
            + O**2 * f**2 * k1**2 * k2**3 * k4
            + O * f * k1 * k2**3 * k4
        )
        A2 = (
            2 * C * O * f**2 * k1**2 * k2**2 * k4
            - C * f * k1**2 * k2**3 * k4
            - 2 * O**2 * f**2 * k1**2 * k2**2 * k4
            - 2 * O * f * k1**2 * k2**3 * k4
            + 2 * O * f * k1 * k2**2 * k4
            - k1 * k2**3 * k4
            + 2 * k2**3 * k3 * k4
            + 2 * k2**3 * k3
        )
        A3 = (
            -C * f * k1**2 * k2**2 * k4
            + O * f * k1**2 * k2**2 * k4
            + O * f * k1 * k2 * k4
            + k1**2 * k2**3 * k4
            - 2 * k1 * k2**2 * k4
            + 6 * k2**2 * k3 * k4
            + 6 * k2**2 * k3
        )
        A4 = -k1 * k2 * k4 + 6 * k2 * k3 * k4 + 6 * k2 * k3
        A5 = 2 * k3 * k4 + 2 * k3
        A6 = 0.0  # pad to degree-6 polynomial
        return jnp.array([A0, A1, A2, A3, A4, A5, A6])

    def HCO_poly6_H2O(self, f, k1, k2, k3, k4):
        """
        HCO chemistry, polynomial in H2O.

        Returns 7 coefficients ordered from the constant term upward; the
        last one is 0.0 so the array has the same length as the HCNO
        polynomials (required to stack them in `_solve_one_layer`).
        """
        C, O = self.C, self.O
        A0 = 2 * O**2 * f**2 * k2**2 * k3 * k4 + 2 * O**2 * f**2 * k2**2 * k3
        A1 = O * f * k1 * k2**2 * k4 - 4 * O * f * k2**2 * k3 * k4 - 4 * O * f * k2**2 * k3
        A2 = (
            -C * f * k1**2 * k2**2 * k4
            + O * f * k1**2 * k2**2 * k4
            + O * f * k1 * k2 * k4
            - k1 * k2**2 * k4
            + 2 * k2**2 * k3 * k4
            + 2 * k2**2 * k3
        )
        A3 = (
            -2 * C * f * k1**2 * k2 * k4
            + 2 * O * f * k1**2 * k2 * k4
            - k1**2 * k2**2 * k4
            - k1 * k2 * k4
        )
        A4 = -C * f * k1**2 * k4 + O * f * k1**2 * k4 - 2 * k1**2 * k2 * k4
        A5 = -k1**2 * k4
        A6 = 0.0  # pad to degree-6 polynomial
        return jnp.array([A0, A1, A2, A3, A4, A5, A6])

    # ---------- HCNO, polynomial in CO ----------
    def HCNO_poly8_CO(self, f, k1, k2, k3, k4, k5, k6):
        """
        JAX version of original HCNO_poly8_CO (CO is the root variable).

        Returns 7 coefficients ordered from the constant term upward.
        `k2` is accepted for signature parity with the other builders but
        does not appear in any coefficient.
        """
        C, N, O = self.C, self.N, self.O
        A0 = 2 * C**2 * O**4 * f**6 * k1**4 * k4**2 * k5
        A1 = -C * O**3 * f**4 * k1**3 * k4**2 * (
            8 * C * f * k1 * k5 + 4 * O * f * k1 * k5 + 4 * k5 - k6
        )
        A2 = (
            O**2 * f**2 * k1**2 * k4 * (
                12 * C**2 * f**2 * k1**2 * k4 * k5
                + 16 * C * O * f**2 * k1**2 * k4 * k5
                + 12 * C * f * k1 * k4 * k5
                - 3 * C * f * k1 * k4 * k6
                - 8 * C * f * k3 * k4 * k5
                - 8 * C * f * k3 * k5
                + C * f * k4 * k6**2
                - N * f * k4 * k6**2
                + 2 * O**2 * f**2 * k1**2 * k4 * k5
                + 4 * O * f * k1 * k4 * k5
                - O * f * k1 * k4 * k6
                + 2 * k4 * k5
                - k4 * k6
            )
        )
        A3 = -O * f * k1 * k4 * (
            8 * C**2 * f**2 * k1**3 * k4 * k5
            + 24 * C * O * f**2 * k1**3 * k4 * k5
            + 12 * C * f * k1**2 * k4 * k5
            - 3 * C * f * k1**2 * k4 * k6
            - 16 * C * f * k1 * k3 * k4 * k5
            - 16 * C * f * k1 * k3 * k5
            + 2 * C * f * k1 * k4 * k6**2
            - 2 * N * f * k1 * k4 * k6**2
            + 8 * O**2 * f**2 * k1**3 * k4 * k5
            + 12 * O * f * k1**2 * k4 * k5
            - 3 * O * f * k1**2 * k4 * k6
            - 8 * O * f * k1 * k3 * k4 * k5
            - 8 * O * f * k1 * k3 * k5
            + O * f * k1 * k4 * k6**2
            + 4 * k1 * k4 * k5
            - 2 * k1 * k4 * k6
            - 8 * k3 * k4 * k5
            + 2 * k3 * k4 * k6
            - 8 * k3 * k5
            + 2 * k3 * k6
            + k4 * k6**2
        )
        A4 = (
            2 * C**2 * f**2 * k1**4 * k4**2 * k5
            + 16 * C * O * f**2 * k1**4 * k4**2 * k5
            + 4 * C * f * k1**3 * k4**2 * k5
            - C * f * k1**3 * k4**2 * k6
            - 8 * C * f * k1**2 * k3 * k4**2 * k5
            - 8 * C * f * k1**2 * k3 * k4 * k5
            + C * f * k1**2 * k4**2 * k6**2
            - N * f * k1**2 * k4**2 * k6**2
            + 12 * O**2 * f**2 * k1**4 * k4**2 * k5
            + 12 * O * f * k1**3 * k4**2 * k5
            - 3 * O * f * k1**3 * k4**2 * k6
            - 16 * O * f * k1**2 * k3 * k4**2 * k5
            - 16 * O * f * k1**2 * k3 * k4 * k5
            + 2 * O * f * k1**2 * k4**2 * k6**2
            + 2 * k1**2 * k4**2 * k5
            - k1**2 * k4**2 * k6
            - 8 * k1 * k3 * k4**2 * k5
            + 2 * k1 * k3 * k4**2 * k6
            - 8 * k1 * k3 * k4 * k5
            + 2 * k1 * k3 * k4 * k6
            + k1 * k4**2 * k6**2
            + 8 * k3**2 * k4**2 * k5
            + 16 * k3**2 * k4 * k5
            + 8 * k3**2 * k5
            - 2 * k3 * k4**2 * k6**2
            - 2 * k3 * k4 * k6**2
        )
        A5 = -k1**2 * k4 * (
            4 * C * f * k1**2 * k4 * k5
            + 8 * O * f * k1**2 * k4 * k5
            + 4 * k1 * k4 * k5
            - k1 * k4 * k6
            - 8 * k3 * k4 * k5
            - 8 * k3 * k5
            + k4 * k6**2
        )
        A6 = 2 * k1**4 * k4**2 * k5
        return jnp.array([A0, A1, A2, A3, A4, A5, A6])

    # ---------- HCNO, polynomial in H2O ----------
    def HCNO_poly8_H2O(self, f, k1, k2, k3, k4, k5, k6):
        """
        JAX version of original HCNO_poly8_H2O (H2O is the root variable).

        Returns 7 coefficients ordered from the constant term upward.
        `k2` is accepted for signature parity with the other builders but
        does not appear in any coefficient.
        """
        C, N, O = self.C, self.N, self.O
        A0 = 2 * O**4 * f**4 * k3 * (k4 + 1.0) * (4 * k3 * k4 * k5 + 4 * k3 * k5 - k4 * k6**2)
        A1 = O**3 * f**3 * (
            8 * k1 * k3 * k4**2 * k5
            - 2 * k1 * k3 * k4**2 * k6
            + 8 * k1 * k3 * k4 * k5
            - 2 * k1 * k3 * k4 * k6
            - k1 * k4**2 * k6**2
            - 32 * k3**2 * k4**2 * k5
            - 64 * k3**2 * k4 * k5
            - 32 * k3**2 * k5
            + 8 * k3 * k4**2 * k6**2
            + 8 * k3 * k4 * k6**2
        )
        A2 = -O**2 * f**2 * (
            8 * C * f * k1**2 * k3 * k4**2 * k5
            + 8 * C * f * k1**2 * k3 * k4 * k5
            - C * f * k1**2 * k4**2 * k6**2
            + N * f * k1**2 * k4**2 * k6**2
            - 8 * O * f * k1**2 * k3 * k4**2 * k5
            - 8 * O * f * k1**2 * k3 * k4 * k5
            + O * f * k1**2 * k4**2 * k6**2
            - 2 * k1**2 * k4**2 * k5
            + k1**2 * k4**2 * k6
            + 24 * k1 * k3 * k4**2 * k5
            - 6 * k1 * k3 * k4**2 * k6
            + 24 * k1 * k3 * k4 * k5
            - 6 * k1 * k3 * k4 * k6
            - 3 * k1 * k4**2 * k6**2
            - 48 * k3**2 * k4**2 * k5
            - 96 * k3**2 * k4 * k5
            - 48 * k3**2 * k5
            + 12 * k3 * k4**2 * k6**2
            + 12 * k3 * k4 * k6**2
        )
        A3 = -O * f * (
            4 * C * f * k1**3 * k4**2 * k5
            - C * f * k1**3 * k4**2 * k6
            - 16 * C * f * k1**2 * k3 * k4**2 * k5
            - 16 * C * f * k1**2 * k3 * k4 * k5
            + 2 * C * f * k1**2 * k4**2 * k6**2
            - 2 * N * f * k1**2 * k4**2 * k6**2
            - 4 * O * f * k1**3 * k4**2 * k5
            + O * f * k1**3 * k4**2 * k6
            + 24 * O * f * k1**2 * k3 * k4**2 * k5
            + 24 * O * f * k1**2 * k3 * k4 * k5
            - 3 * O * f * k1**2 * k4**2 * k6**2
            + 4 * k1**2 * k4**2 * k5
            - 2 * k1**2 * k4**2 * k6
            - 24 * k1 * k3 * k4**2 * k5
            + 6 * k1 * k3 * k4**2 * k6
            - 24 * k1 * k3 * k4 * k5
            + 6 * k1 * k3 * k4 * k6
            + 3 * k1 * k4**2 * k6**2
            + 32 * k3**2 * k4**2 * k5
            + 64 * k3**2 * k4 * k5
            + 32 * k3**2 * k5
            - 8 * k3 * k4**2 * k6**2
            - 8 * k3 * k4 * k6**2
        )
        A4 = (
            2 * C**2 * f**2 * k1**4 * k4**2 * k5
            - 4 * C * O * f**2 * k1**4 * k4**2 * k5
            + 4 * C * f * k1**3 * k4**2 * k5
            - C * f * k1**3 * k4**2 * k6
            - 8 * C * f * k1**2 * k3 * k4**2 * k5
            - 8 * C * f * k1**2 * k3 * k4 * k5
            + C * f * k1**2 * k4**2 * k6**2
            - N * f * k1**2 * k4**2 * k6**2
            + 2 * O**2 * f**2 * k1**4 * k4**2 * k5
            - 8 * O * f * k1**3 * k4**2 * k5
            + 2 * O * f * k1**3 * k4**2 * k6
            + 24 * O * f * k1**2 * k3 * k4**2 * k5
            + 24 * O * f * k1**2 * k3 * k4 * k5
            - 3 * O * f * k1**2 * k4**2 * k6**2
            + 2 * k1**2 * k4**2 * k5
            - k1**2 * k4**2 * k6
            - 8 * k1 * k3 * k4**2 * k5
            + 2 * k1 * k3 * k4**2 * k6
            - 8 * k1 * k3 * k4 * k5
            + 2 * k1 * k3 * k4 * k6
            + k1 * k4**2 * k6**2
            + 8 * k3**2 * k4**2 * k5
            + 16 * k3**2 * k4 * k5
            + 8 * k3**2 * k5
            - 2 * k3 * k4**2 * k6**2
            - 2 * k3 * k4 * k6**2
        )
        A5 = k1**2 * k4 * (
            4 * C * f * k1**2 * k4 * k5
            - 4 * O * f * k1**2 * k4 * k5
            + 4 * k1 * k4 * k5
            - k1 * k4 * k6
            - 8 * k3 * k4 * k5
            - 8 * k3 * k5
            + k4 * k6**2
        )
        A6 = 2 * k1**4 * k4**2 * k5
        return jnp.array([A0, A1, A2, A3, A4, A5, A6])

    # ---------- Newton–Raphson (bounded) using Optimistix ----------
    @staticmethod
    def _eval_poly(A: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
        """
        Evaluate polynomial at x.

        Parameters
        ----------
        A : array, shape (n,)
            Polynomial coefficients, A[0] = constant, A[-1] = highest degree
        x : scalar array
            Point at which to evaluate polynomial

        Returns
        -------
        p(x) : scalar array
            Polynomial value at x
        """
        return jnp.polyval(A[::-1], x)  # polyval expects highest degree first

    @classmethod
    def newton_raphson_bounded(
        cls,
        A: jnp.ndarray,
        guess: float,
        vmax: float,
        xtol: float = 1e-10,
        imax: int = 80,
        kmax: int = 10,
    ) -> jnp.ndarray:
        """
        Robust polynomial root finding with bounded domain using Optimistix.

        Runs Newton's method from `guess`, `guess/10`, `guess/100`, ... (at
        most `kmax` starts) until an attempt converges to a finite root in
        `[0, vmax]`. The last Newton result is clamped to `[0, vmax]`; if it
        is not acceptable, bisection on `[0, vmax]` is used when the
        polynomial changes sign over that interval, and the midpoint
        `0.5 * vmax` is returned otherwise.

        Parameters
        ----------
        A : array, shape (n,)
            Polynomial coefficients (constant to highest degree)
        guess : float
            Initial guess for root
        vmax : float
            Maximum valid value for root
        xtol : float
            Relative/absolute tolerance for convergence
        imax : int
            Maximum iterations per attempt
        kmax : int
            Maximum number of retry attempts with scaled guesses

        Returns
        -------
        root : scalar array
            Root of polynomial, clamped to [0, vmax]
        """
        # Optimistix tolerances should be set in line with dtype. If x64 is disabled,
        # JAX will use float32, and tighter tolerances than ~1e-6 are ineffective.
        xtol_eff = float(xtol)
        if A.dtype == jnp.float32:
            xtol_eff = max(xtol_eff, 1e-6)

        # Public success member of the Optimistix result enumeration (the
        # private `_name_to_item` lookup table is not part of its API).
        result_success = optx.RESULTS.successful

        def poly_fn(x, _args):
            """Function to find root of: polynomial(x) = 0"""
            return cls._eval_poly(A, x)

        def try_solve(guess_scaled: float) -> tuple[jnp.ndarray, jnp.ndarray]:
            """Run one Newton solve from `guess_scaled`; return (root, ok)."""
            solver = optx.Newton(rtol=xtol_eff, atol=xtol_eff)
            sol = optx.root_find(
                poly_fn,
                solver,
                y0=guess_scaled,
                args=None,
                max_steps=imax,
                throw=False,  # Don't raise on convergence failure
            )
            # Accept either explicit success or sufficiently small residual.
            # NOTE(review): `sol.state.f` reads the residual from the Newton
            # solver state; confirm this field on Optimistix upgrades.
            fval = sol.state.f
            ok = (sol.result == result_success) | (jnp.abs(fval) <= 10.0 * xtol_eff)
            return sol.value, ok

        # Try multiple initial guesses with decreasing scales
        def attempt_body(carry):
            _, _, attempt = carry
            scale = 10.0 ** (-attempt)
            root_candidate, ok_candidate = try_solve(guess * scale)
            return (root_candidate, ok_candidate, attempt + 1)

        def attempt_cond(carry):
            root, ok, attempt = carry
            finite = jnp.isfinite(root)
            in_range = jnp.logical_and(root >= 0.0, root <= vmax)
            good = jnp.logical_and(ok, jnp.logical_and(finite, in_range))
            more_attempts = attempt < kmax
            # keep trying while root is bad and we still have attempts left
            return jnp.logical_and(~good, more_attempts)

        # Start with bogus root so we always do at least one attempt
        carry0 = (jnp.array(-1.0), jnp.array(False), jnp.int32(0))
        root_final, ok_final, _ = lax.while_loop(attempt_cond, attempt_body, carry0)

        def _bisect_root() -> jnp.ndarray:
            """Bisection fallback on [0, vmax] if a sign change exists."""
            lo = jnp.array(0.0, dtype=A.dtype)
            hi = jnp.asarray(vmax, dtype=A.dtype)
            flo = cls._eval_poly(A, lo)
            fhi = cls._eval_poly(A, hi)
            n_iter = 60 if A.dtype == jnp.float64 else 40

            def body(_, state):
                lo, hi, flo, fhi = state
                mid = 0.5 * (lo + hi)
                fmid = cls._eval_poly(A, mid)
                # If fmid is NaN/Inf, shrink interval conservatively.
                bad = ~jnp.isfinite(fmid)
                go_left = jnp.where(bad, True, flo * fmid <= 0.0)
                lo2 = jnp.where(go_left, lo, mid)
                hi2 = jnp.where(go_left, mid, hi)
                flo2 = jnp.where(go_left, flo, fmid)
                fhi2 = jnp.where(go_left, fmid, fhi)
                return lo2, hi2, flo2, fhi2

            lo, hi, _, _ = lax.fori_loop(0, n_iter, body, (lo, hi, flo, fhi))
            return 0.5 * (lo + hi)

        def _fallback_root() -> jnp.ndarray:
            # If Newton fails entirely (NaN or no acceptable candidate), return a safe interior point.
            return jnp.asarray(0.5, dtype=A.dtype) * jnp.asarray(vmax, dtype=A.dtype)

        # The Newton result is clamped first, so a converged root that lies
        # outside [0, vmax] is accepted at the nearest bound; bisection or
        # the midpoint fallback is used only when Newton did not converge or
        # produced a non-finite value.
        root_final = jnp.clip(root_final, 0.0, vmax)
        newton_good = ok_final & jnp.isfinite(root_final)
        f0 = cls._eval_poly(A, jnp.array(0.0, dtype=A.dtype))
        f1 = cls._eval_poly(A, jnp.asarray(vmax, dtype=A.dtype))
        sign_change = jnp.isfinite(f0) & jnp.isfinite(f1) & (f0 * f1 <= 0.0)
        return lax.cond(
            newton_good,
            lambda: root_final,
            lambda: lax.cond(sign_change, _bisect_root, _fallback_root),
        )

    # ---------- Rest of species from H2O & CO ----------
    def solve_rest(
        self,
        H2O: float,
        CO: float,
        f: float,
        k1: float,
        k2: float,
        k3: float,
        k4: float,
        k5: float,
        k6: float,
    ) -> jnp.ndarray:
        """
        Derive the remaining heavy species of one layer from H2O and CO.

        Also reads `self.N` for the nitrogen balance
        `NH3 + HCN + 2*N2 = f*N`.

        Returns [H2O, CH4, CO, CO2, NH3, C2H2, C2H4, HCN, N2]
        """
        # Tiny offsets guarding the divisions below against exact zeros.
        # NOTE(review): 1e-300 is below the float32 range, so this guard is
        # only effective when the inputs are float64.
        eps = 1e-300
        k1_safe = k1 + eps
        k2_safe = k2 + eps
        k4_safe = k4 + eps
        k5_safe = k5 + eps
        H2O_safe = H2O + eps

        CH4 = CO / (k1_safe * H2O_safe)
        CO2 = CO * H2O_safe / k2_safe
        C2H2 = k3 * CH4**2
        C2H4 = C2H2 / k4_safe

        # Nitrogen balance with HCN = k6*NH3*CH4 and N2 = k5*NH3^2 gives
        #   2*k5*NH3^2 + b*NH3 - f*N = 0,   b = 1 + k6*CH4.
        # The positive root (sqrt(disc) - b) / (4*k5) is written in its
        # conjugate form so it does not cancel when 8*f*k5*N << b^2; in that
        # limit it reduces to f*N/b.
        b = 1.0 + k6 * CH4
        disc = b**2 + 8.0 * f * k5_safe * self.N
        NH3 = 2.0 * f * self.N / (b + jnp.sqrt(disc))

        HCN = k6 * NH3 * CH4
        N2 = k5_safe * NH3**2
        return jnp.array([H2O, CH4, CO, CO2, NH3, C2H2, C2H4, HCN, N2])

    # ---------- Choose polynomial index (0..3) per layer ----------
    def _choose_poly_index(self, T_i: float, p_i: float):
        """
        Select which polynomial to solve for a layer.

        Encodes the logic:
            0 -> HCO_poly6_CO
            1 -> HCO_poly6_H2O
            2 -> HCNO_poly8_CO
            3 -> HCNO_poly8_H2O

        Implemented with lax.cond so it works under jit.
        Returns a JAX int32 scalar.
        """
        C, N, O = self.C, self.N, self.O
        C_over_O = C / O
        N_over_C = N / C

        def branch_CO_lt1(_):
            # C/O < 1
            cond_N_hot = jnp.logical_and(N_over_C > 10.0, T_i > 2200.0)

            def when_N_hot(_):
                def when_C_over_O_mid(_):
                    return jnp.int32(3)  # HCNO H2O

                def when_C_over_O_low(_):
                    return jnp.int32(2)  # HCNO CO

                return lax.cond(C_over_O > 0.1, when_C_over_O_mid, when_C_over_O_low, None)

            def when_else(_):
                return jnp.int32(0)  # HCO CO

            return lax.cond(cond_N_hot, when_N_hot, when_else, None)

        def branch_CO_ge1(_):
            # C/O >= 1: the choice depends on which side of the turnover
            # pressure the layer lies.
            turn = RateJAX.top(T_i, C, N, O)
            cond_lower = p_i > turn

            def when_lower(_):
                return jnp.int32(0)  # HCO CO

            def when_upper(_):
                cond_N_hot2 = jnp.logical_and(N_over_C > 0.1, T_i > 900.0)

                def when_N_hot2(_):
                    return jnp.int32(3)  # HCNO H2O

                def when_else2(_):
                    return jnp.int32(1)  # HCO H2O

                return lax.cond(cond_N_hot2, when_N_hot2, when_else2, None)

            return lax.cond(cond_lower, when_lower, when_upper, None)

        cond_CO_lt1 = C_over_O < 1.0
        # IMPORTANT: DO NOT wrap this in Python int()
        return lax.cond(cond_CO_lt1, branch_CO_lt1, branch_CO_ge1, None)

    # ---------- Solve one layer ----------
    def _solve_one_layer(
        self,
        T_i: float,
        p_i: float,
        f_i: float,
        k1_i: float,
        k2_i: float,
        k3_i: float,
        k4_i: float,
        k5_i: float,
        k6_i: float,
    ) -> jnp.ndarray:
        """
        Solve the heavy-species abundances (normalized to H2) for one layer.

        Picks the governing polynomial, finds its root in `[0, vmax]`,
        recovers the complementary H2O/CO abundance from the oxygen balance
        and returns the nine-species vector produced by `solve_rest`.
        """
        C, O = self.C, self.O
        # 0 -> HCO_poly6_CO
        # 1 -> HCO_poly6_H2O
        # 2 -> HCNO_poly8_CO
        # 3 -> HCNO_poly8_H2O
        poly_idx = self._choose_poly_index(T_i, p_i)  # JAX int scalar
        is_H2O_var = (poly_idx % 2 == 1)

        # Build all four coefficient sets (each length 7) and select by
        # index; this keeps the computation traceable for a traced index.
        A_HCO_CO = self.HCO_poly6_CO(f_i, k1_i, k2_i, k3_i, k4_i)
        A_HCO_H2O = self.HCO_poly6_H2O(f_i, k1_i, k2_i, k3_i, k4_i)
        A_HCNO_CO = self.HCNO_poly8_CO(f_i, k1_i, k2_i, k3_i, k4_i, k5_i, k6_i)
        A_HCNO_H2O = self.HCNO_poly8_H2O(f_i, k1_i, k2_i, k3_i, k4_i, k5_i, k6_i)
        A_all = jnp.stack([A_HCO_CO, A_HCO_H2O, A_HCNO_CO, A_HCNO_H2O], axis=0)
        A = A_all[poly_idx]  # shape (7,)

        # Bounds for the root
        vmax_H2O = f_i * O
        vmax_CO = f_i * jnp.minimum(C, O)
        vmax = jnp.where(is_H2O_var, vmax_H2O, vmax_CO)
        guess = 0.99 * vmax

        # More stable multi-guess NR:
        root = self.newton_raphson_bounded(A, guess, vmax)

        # Oxygen balance H2O + CO + 2*CO*H2O/k2 = f*O is symmetric in CO and
        # H2O, so the same expression recovers whichever species was not the
        # root variable.
        partner = (f_i * O - root) / (1.0 + 2.0 * root / k2_i)
        H2O = jnp.where(is_H2O_var, root, partner)
        CO = jnp.where(is_H2O_var, partner, root)

        # Remaining species (normalized to H2):
        return self.solve_rest(H2O, CO, f_i, k1_i, k2_i, k3_i, k4_i, k5_i, k6_i)

    # ---------- Hydrogen partition ----------
    @staticmethod
    def _hydrogen_partition(k0: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """
        Split hydrogen between H and H2 for the dissociation constant `k0`.

        Solves `(2/k0)*H^2 + H - 1 = 0` with `H2 = H^2/k0`. The positive root
        `(-1 + sqrt(1 + 8/k0)) / (4/k0)` is evaluated in the algebraically
        identical form `2 / (1 + sqrt(1 + 8/k0))`, which avoids the 0/0-like
        cancellation the direct form suffers once `8/k0` drops below the
        floating-point resolution (large `k0`).

        Returns
        -------
        Hatom, Hmol, f : arrays
            Atomic hydrogen, molecular hydrogen, and the scaling factor
            `f = (Hatom + 2*Hmol) / Hmol`.
        """
        Hatom = 2.0 / (1.0 + jnp.sqrt(1.0 + 8.0 / k0))
        Hmol = Hatom**2 / k0  # n(H2)
        f = (Hatom + 2.0 * Hmol) / Hmol
        return Hatom, Hmol, f

    # ---------- Main public API: solve profile & return VMR dict ----------
    def solve_profile(
        self,
        T: jnp.ndarray,
        p: jnp.ndarray,
        return_diagnostics: bool = False,
    ) -> Union[Dict[str, jnp.ndarray], Tuple[Dict[str, jnp.ndarray], Dict]]:
        """
        Solve thermochemical equilibrium across a 1D T-p profile.

        Parameters
        ----------
        T : 1D array [K]
            Temperature profile
        p : 1D array [bar]
            Pressure profile
        return_diagnostics : bool, optional
            If True, return (vmr_dict, diagnostics). The diagnostics are
            converted with `float(...)`, so this option needs concrete
            (non-traced) inputs.

        Returns
        -------
        vmr : dict[str, jnp.ndarray]
            Keys: self.species, each value shape = (nlayers,)
        diagnostics : dict, optional
            Only returned if return_diagnostics=True. Contains:
            - 'n_layers': number of layers
            - 'T_range': (min, max) temperature
            - 'p_range': (min, max) pressure
            - 'T_mean': mean temperature
            - 'p_mean_log': mean of log10(pressure)

        Raises
        ------
        ValueError
            If T or p is not 1D, or their shapes differ.
        """
        # Convert to arrays
        T = jnp.asarray(T)
        p = jnp.asarray(p)

        # Shape validation only: shapes are static, so these checks also work
        # under jit. Value validation (T > 0, p > 0) is left to the caller.
        if T.ndim != 1:
            raise ValueError(f"Temperature must be 1D array, got {T.ndim}D")
        if p.ndim != 1:
            raise ValueError(f"Pressure must be 1D array, got {p.ndim}D")
        if T.shape != p.shape:
            raise ValueError(
                f"Temperature and pressure must have same shape, "
                f"got T.shape={T.shape} and p.shape={p.shape}"
            )
        nlayers = T.shape[0]

        # Equilibrium constants (vectorized):
        k0 = self.kprime0(T, p)
        k1 = self.kprime1(T, p)
        k2 = self.kprime2(T)
        k3 = self.kprime3(T, p)
        k4 = self.kprime4(T, p)
        k5 = self.kprime5(T, p)
        k6 = self.kprime6(T, p)

        # Avoid exact 0/inf constants (which break algebra downstream).
        # Use dtype-aware bounds to avoid float32 overflow/underflow when x64 is disabled.
        if k0.dtype == jnp.float64:
            k_min, k_max = 1e-300, 1e300
        else:
            finfo = jnp.finfo(jnp.float32)
            k_min, k_max = finfo.tiny, finfo.max
        k0, k1, k2, k3, k4, k5, k6 = (
            jnp.clip(k, k_min, k_max) for k in (k0, k1, k2, k3, k4, k5, k6)
        )

        # Hydrogen chemistry (cancellation-free quadratic root):
        Hatom, Hmol, f = self._hydrogen_partition(k0)

        # Solve heavy species per layer with vmap:
        solve_layer_vmapped = jax.vmap(
            self._solve_one_layer,
            in_axes=(0, 0, 0, 0, 0, 0, 0, 0, 0),
        )
        # shape: (nlayers, 9) -> transpose to (9, nlayers)
        heavy_norm = solve_layer_vmapped(T, p, f, k1, k2, k3, k4, k5, k6).T

        # De-normalize by H2 to get absolute number ratios:
        heavy = heavy_norm * Hmol

        # H2, H, He
        H2 = Hmol
        H = Hatom
        He = self.fHe * (2.0 * H2 + H)

        # Stack all species in the same order as self.species:
        all_species = jnp.vstack([heavy, H2[None, :], H[None, :], He[None, :]])

        # Convert to VMR: normalize by total:
        total = jnp.sum(all_species, axis=0, keepdims=True)
        vmr_all = all_species / total

        # Build dictionary: each species -> (nlayers,)
        vmr_dict: Dict[str, jnp.ndarray] = {
            name: vmr_all[i, :] for i, name in enumerate(self.species)
        }

        # Optionally return diagnostics
        if return_diagnostics:
            diagnostics = {
                'n_layers': nlayers,
                'T_range': (float(jnp.min(T)), float(jnp.max(T))),
                'p_range': (float(jnp.min(p)), float(jnp.max(p))),
                'T_mean': float(jnp.mean(T)),
                'p_mean_log': float(jnp.mean(jnp.log10(p))),
            }
            return vmr_dict, diagnostics
        return vmr_dict