"""
opacity_line.py
===============
"""
from typing import Dict
import jax.numpy as jnp
from jax import lax
from . import build_opacities as XS
from .data_constants import amu, bar
# Public API of this module: the two interchangeable line-opacity entry points.
__all__ = [
    "zero_line_opacity",
    "compute_line_opacity",
]
# (Sphinx "[docs]" link marker left over from the rendered source listing.)
def zero_line_opacity(state: Dict[str, jnp.ndarray], params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """Build an all-zero stand-in for the opacity-sampling line opacity.

    Intended as the fallback when opacity-sampling opacities are disabled in
    the configuration: it returns an array laid out like the result of
    `compute_line_opacity()`, filled with zeros, so downstream code can add
    it to other opacity sources unchanged.

    NOTE(review): this function takes ``(state, params)`` while
    `compute_line_opacity()` takes ``(state, opac, params)``; callers that
    switch between the two must account for the missing ``opac`` argument.

    Parameters
    ----------
    state : dict[str, `~jax.numpy.ndarray`]
        State dictionary. Only two entries are read:

        - `nlay` : int
            Number of atmospheric layers.
        - `nwl` : int
            Number of wavelength points.
    params : dict[str, `~jax.numpy.ndarray`]
        Ignored; accepted only so the signature resembles the other opacity
        functions.

    Returns
    -------
    zeros : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Zero-valued line opacity array in cm² g⁻¹.
    """
    # Read the dimensions straight from `state` (no jnp.size() call) so the
    # output shape stays a static value under JIT.
    n_layers = state["nlay"]
    n_wavelengths = state["nwl"]
    return jnp.zeros((n_layers, n_wavelengths))
# (Sphinx "[docs]" link marker left over from the rendered source listing.)
def compute_line_opacity(state: Dict[str, jnp.ndarray], opac: Dict[str, jnp.ndarray], params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """Compute opacity-sampling mass opacity for all molecular/atomic absorbers.

    The total line absorption opacity is obtained by:

    1. Bilinearly interpolating the pre-loaded log10 cross-sections to each
       layer's (log10 P, log10 T).
    2. Weighting each species' cross-section by its volume mixing ratio.
    3. Summing the contributions from all species.
    4. Dividing by the mean molecular mass to convert the per-molecule
       cross-section into a mass opacity.

    Parameters
    ----------
    state : dict[str, `~jax.numpy.ndarray`]
        Atmospheric state dictionary containing:

        - `p_lay` : `~jax.numpy.ndarray`, shape (nlay,)
            Layer pressures in dyne cm⁻².
        - `T_lay` : `~jax.numpy.ndarray`, shape (nlay,)
            Layer temperatures in Kelvin.
        - `mu_lay` : `~jax.numpy.ndarray`, shape (nlay,)
            Mean molecular weight per layer in g mol^-1.
        - `vmr_lay` : dict[str, `~jax.numpy.ndarray`]
            Volume mixing ratios for each species. Keys must match the
            species names returned by ``XS.line_runtime_species_order()``.
            Values can be scalars or arrays with shape (nlay,).
    opac : dict[str, `~jax.numpy.ndarray`]
        Pre-loaded line opacity tables:

        - `line_sigma_cube` : `~jax.numpy.ndarray`
            log10 cross-sections in cm², indexed as
            (species, temperature, pressure, wavelength).
        - `line_log10_pressure_grid` : `~jax.numpy.ndarray`, 1-D
            Sorted log10 pressure grid shared by all species. It is compared
            against ``log10(p_lay / bar)`` (presumably log10 of pressure in
            bar; confirm against `data_constants.bar`).
        - `line_log10_temperature_grids` : `~jax.numpy.ndarray`
            One sorted log10 temperature grid per species, indexed as
            (species, temperature).
    params : dict[str, `~jax.numpy.ndarray`]
        Parameter dictionary (unused; VMRs come from state['vmr_lay']).
        Kept for API compatibility with other opacity functions.

    Returns
    -------
    kappa_line : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Total line absorption mass opacity in cm² g⁻¹ at each layer and
        wavelength point.

    Raises
    ------
    ValueError
        If `line_sigma_cube` contains more species slices than there are
        runtime species names.
    KeyError
        If a runtime species name has no entry in ``state['vmr_lay']``.

    Notes
    -----
    Layers outside the tabulated pressure or temperature range are not
    extrapolated: the interpolation weights are clipped to [0, 1], so the
    nearest table edge is used.
    """
    layer_pressures = state["p_lay"]
    layer_temperatures = state["T_lay"]
    layer_mu = state["mu_lay"]
    layer_vmr = state["vmr_lay"]

    # Species order defines which VMR belongs to which slice of the cube.
    species_names = tuple(XS.line_runtime_species_order())
    layer_count = layer_pressures.shape[0]

    sigma_cube = opac["line_sigma_cube"]
    log_p_grid = opac["line_log10_pressure_grid"]
    log_temperature_grids = opac["line_log10_temperature_grids"]

    n_species = sigma_cube.shape[0]
    n_wl = sigma_cube.shape[-1]

    # Shapes are static, so this check runs at trace time and is JIT-safe.
    # Without it, `mixing_ratios[i]` inside the loop would be indexed out of
    # bounds; JAX clamps such indices silently, which would weight the extra
    # cross-sections with the last species' mixing ratio.
    if len(species_names) < n_species:
        raise ValueError(
            f"line_sigma_cube holds {n_species} species but only "
            f"{len(species_names)} runtime species names are available; "
            "cannot assign a mixing ratio to every cross-section table."
        )

    # Direct lookup - species names must match VMR keys exactly.
    # Scalars are broadcast so every species contributes a (nlay,) row.
    mixing_ratios = jnp.stack(
        [jnp.broadcast_to(layer_vmr[name], (layer_count,)) for name in species_names],
        axis=0,
    )

    layer_pressures_bar = layer_pressures / bar
    log_p_layers = jnp.log10(layer_pressures_bar)
    log_t_layers = jnp.log10(layer_temperatures)

    # Lower bracketing index on the pressure grid, kept inside [0, nP - 2]
    # so that `p_idx + 1` is always a valid upper neighbour.
    p_idx = jnp.searchsorted(log_p_grid, log_p_layers) - 1
    p_idx = jnp.clip(p_idx, 0, log_p_grid.shape[0] - 2)
    p_weight = (log_p_layers - log_p_grid[p_idx]) / (log_p_grid[p_idx + 1] - log_p_grid[p_idx])
    p_weight = jnp.clip(p_weight, 0.0, 1.0)

    # The pressure grid is shared by all species, so its broadcast weights
    # are loop-invariant and are built once here.
    p_hi = p_weight[:, None]
    p_lo = 1.0 - p_hi

    def _accumulate_one(i: jnp.ndarray, acc: jnp.ndarray) -> jnp.ndarray:
        """Add species ``i``'s VMR-weighted cross-section to ``acc``."""
        sigma_3d = sigma_cube[i]  # (nT, nP, nwl) in log10(cm^2)
        log_temp_grid = log_temperature_grids[i]  # (nT,)

        # The temperature grid is per species, so bracket it inside the loop.
        t_idx = jnp.searchsorted(log_temp_grid, log_t_layers) - 1
        t_idx = jnp.clip(t_idx, 0, log_temp_grid.shape[0] - 2)
        t_weight = (log_t_layers - log_temp_grid[t_idx]) / (log_temp_grid[t_idx + 1] - log_temp_grid[t_idx])
        t_weight = jnp.clip(t_weight, 0.0, 1.0)

        # Four corners of the (T, P) cell enclosing each layer.
        s_t0_p0 = sigma_3d[t_idx, p_idx, :]  # (nlay, nwl)
        s_t0_p1 = sigma_3d[t_idx, p_idx + 1, :]
        s_t1_p0 = sigma_3d[t_idx + 1, p_idx, :]
        s_t1_p1 = sigma_3d[t_idx + 1, p_idx + 1, :]

        # Interpolate along pressure first, then along temperature; both
        # steps act on log10(sigma).
        s_t0 = p_lo * s_t0_p0 + p_hi * s_t0_p1
        s_t1 = p_lo * s_t1_p0 + p_hi * s_t1_p1
        s_interp = (1.0 - t_weight)[:, None] * s_t0 + t_weight[:, None] * s_t1  # log10 sigma

        sigma_linear = 10.0 ** s_interp.astype(jnp.float64)  # (nlay, nwl) in cm^2
        return acc + sigma_linear * mixing_ratios[i, :, None]

    weighted_sigma = lax.fori_loop(
        0,
        n_species,
        _accumulate_one,
        jnp.zeros((layer_count, n_wl), dtype=jnp.float64),
    )

    # cm^2 per molecule -> cm^2 per gram: divide by the mean molecular mass.
    return weighted_sigma / (layer_mu[:, None] * amu)