Source code for exo_skryer.opacity_ray

"""
opacity_ray.py
==============
"""

from __future__ import annotations
from typing import Dict

import jax.numpy as jnp
from . import registry_ray as XR

# Public API of this module: the two Rayleigh opacity functions below.
__all__ = [
    "zero_ray_opacity",
    "compute_ray_opacity"
]


def zero_ray_opacity(state: Dict[str, jnp.ndarray], params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """Build an all-zero Rayleigh scattering opacity grid.

    The result has the same ``(nlay, nwl)`` layout as the array returned by
    `compute_ray_opacity`, but every entry is zero.

    Parameters
    ----------
    state : dict[str, `~jax.numpy.ndarray`]
        State dictionary providing:

        - `nlay` : int
            Number of atmospheric layers.
        - `nwl` : int
            Number of wavelength points.

        Both sizes must be concrete (static) integers. `jax.numpy.zeros`
        cannot build an array whose shape is a traced value under `jax.jit`.
    params : dict[str, `~jax.numpy.ndarray`]
        Parameter dictionary. Ignored; present only so the signature matches
        the other opacity functions.

    Returns
    -------
    zeros : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Zero-valued Rayleigh opacity array in cm² g⁻¹.
    """
    n_layers = state["nlay"]
    n_wavelengths = state["nwl"]
    # Sizes are forwarded unchanged (no int() cast) straight into the shape tuple.
    return jnp.zeros((n_layers, n_wavelengths))


def compute_ray_opacity(
    state: Dict[str, jnp.ndarray],
    opac: Dict[str, jnp.ndarray],
    params: Dict[str, jnp.ndarray],
) -> jnp.ndarray:
    """Compute Rayleigh scattering mass opacity for the configured scatterers.

    Converts precomputed Rayleigh scattering cross-sections (indexed by
    `exo_skryer.registry_ray.ray_runtime_species_order`) into a
    layer-by-wavelength mass opacity in cm² g⁻¹::

        kappa[l, w] = nd_lay[l] * sum_s(vmr[s, l] * sigma[s, w]) / rho_lay[l]

    Zeros with the expected ``(nlay, nwl)`` shape are returned in two cases.
    The first is when no Rayleigh data are loaded, meaning ``opac`` has no
    (or a ``None``) ``"ray_sigma_linear_table"`` entry. The second is when
    the runtime species order is empty.

    Parameters
    ----------
    state : dict[str, `~jax.numpy.ndarray`]
        Atmospheric state dictionary containing:

        - `wl` : `~jax.numpy.ndarray`, shape `(nwl,)`
            Forward-model wavelength grid in microns.
        - `nd_lay` : `~jax.numpy.ndarray`, shape `(nlay,)`
            Layer total number density in cm⁻³.
        - `rho_lay` : `~jax.numpy.ndarray`, shape `(nlay,)`
            Layer mass density in g cm⁻³.
        - `vmr_lay` : dict[str, `~jax.numpy.ndarray`]
            Volume mixing ratios for each Rayleigh species. Must contain
            every name returned by `registry_ray.ray_runtime_species_order()`.
            Values may be scalars or arrays with shape (nlay,).
    opac : dict[str, `~jax.numpy.ndarray`]
        Opacity dictionary. When Rayleigh data are loaded it contains:

        - `ray_master_wavelength` : `~jax.numpy.ndarray`, shape `(nwl,)`
            Wavelength grid of the cross-section table. Only its shape is
            compared with `state["wl"]`, because values may be tracers under
            `jax.jit`.
        - `ray_sigma_linear_table` : `~jax.numpy.ndarray`, shape `(n_species, nwl)`
            Per-molecule cross-sections. They are in cm², as implied by the
            output units. Rows are assumed to follow the order of
            `registry_ray.ray_runtime_species_order()`.
    params : dict[str, `~jax.numpy.ndarray`]
        Parameter dictionary (unused; kept for API compatibility).

    Returns
    -------
    kappa_ray : `~jax.numpy.ndarray`, shape (nlay, nwl)
        Rayleigh scattering mass opacity in cm² g⁻¹ at each layer and
        wavelength.

    Raises
    ------
    ValueError
        If the Rayleigh wavelength grid shape differs from `state["wl"]`.
    KeyError
        If a Rayleigh species is missing from `state["vmr_lay"]`.
    """
    wavelengths = state["wl"]
    number_density = state["nd_lay"]
    density = state["rho_lay"]
    layer_count = number_density.shape[0]
    out_shape = (layer_count, wavelengths.shape[0])

    # No Rayleigh table loaded: honor the documented zero-opacity contract
    # instead of failing with a KeyError.
    sigma_values = opac.get("ray_sigma_linear_table")  # (n_species, nwl)
    if sigma_values is None:
        return jnp.zeros(out_shape)

    master_wavelength = opac["ray_master_wavelength"]
    # Shapes are static under jax.jit, so this check is trace-safe; values are not compared.
    if master_wavelength.shape != wavelengths.shape:
        raise ValueError("Rayleigh wavelength grid must match forward-model grid.")

    # Materialize once so the emptiness test and the iteration below see the same names.
    species_names = tuple(XR.ray_runtime_species_order())
    # jnp.stack rejects an empty sequence; with no scatterers the species sum is zero.
    if not species_names:
        return jnp.zeros(out_shape)

    layer_vmr = state["vmr_lay"]
    mixing_ratios = jnp.stack(
        [jnp.broadcast_to(layer_vmr[name], (layer_count,)) for name in species_names],
        axis=0,
    )  # (n_species, nlay)

    # Sum over species of vmr * sigma -> (nlay, nwl) mean cross-section per molecule.
    sigma_weighted = jnp.einsum("sl,sw->lw", mixing_ratios, sigma_values)
    # n * sigma [cm^-1] divided by rho [g cm^-3] gives cm^2 g^-1.
    return (number_density[:, None] * sigma_weighted) / density[:, None]