Source code for exo_skryer.vert_mu
"""
vert_mu.py
==========
"""
from __future__ import annotations
from typing import Callable, Dict
import jax.numpy as jnp
from .data_constants import CHEM_SPECIES_DATA
_SPECIES_MASS = {entry["symbol"]: float(entry["molecular_weight"]) for entry in CHEM_SPECIES_DATA}
__all__ = [
"constant_mu",
"compute_mu",
"build_compute_mu",
]
[docs]
def constant_mu(params: Dict[str, jnp.ndarray], nlay: int) -> jnp.ndarray:
"""Return a constant mean molecular weight (μ) profile.
Parameters
----------
params : dict[str, `~jax.numpy.ndarray`]
Parameter dictionary containing:
- `mu` : float
Mean molecular weight in g mol⁻¹
nlay : int
Number of atmospheric layers.
Returns
-------
mu_lay : `~jax.numpy.ndarray`, shape (nlay,)
Mean molecular weight profile in g mol⁻¹.
"""
if "mu" not in params:
raise ValueError("vert_mu='constant' requires a 'mu' parameter.")
mu_const = params["mu"]
return jnp.full((nlay,), mu_const)
[docs]
def compute_mu(vmr_lay: Dict[str, jnp.ndarray]) -> jnp.ndarray:
"""Compute mean molecular weight from volume mixing ratios.
Parameters
----------
vmr_lay : dict[str, `~jax.numpy.ndarray`]
Dictionary mapping species symbols to their VMR profiles. Each value
should be an array of shape (nlay,).
Returns
-------
mu_lay : `~jax.numpy.ndarray`, shape (nlay,)
Mean molecular weight profile in g mol⁻¹.
"""
# Electrons (e-) have negligible mass and should not affect mean molecular weight.
species_list = sorted(
species for species in vmr_lay.keys() if species in _SPECIES_MASS and species != "e-"
)
if not species_list:
raise ValueError("No valid species provided to compute mean molecular weight.")
vmr_arrays = [vmr_lay[sp] for sp in species_list]
masses = jnp.array([_SPECIES_MASS[sp] for sp in species_list])
vmr_stack = jnp.stack(vmr_arrays, axis=0)
mu_profile = jnp.sum(vmr_stack * masses[:, None], axis=0)
return mu_profile
[docs]
def build_compute_mu(species_order: tuple[str, ...]) -> Callable[[Dict[str, jnp.ndarray]], jnp.ndarray]:
"""Build a mean-molecular-weight kernel with a fixed species ordering."""
valid_species = tuple(
species for species in species_order
if species in _SPECIES_MASS and species != "e-"
)
if not valid_species:
raise ValueError("No valid non-electron species were provided for mean molecular weight.")
masses = jnp.asarray([_SPECIES_MASS[species] for species in valid_species])
def _compute_mu_fixed(vmr_lay: Dict[str, jnp.ndarray]) -> jnp.ndarray:
vmr_stack = jnp.stack([vmr_lay[species] for species in valid_species], axis=0)
return jnp.sum(vmr_stack * masses[:, None], axis=0)
return _compute_mu_fixed