# Source code for exo_skryer.build_model

"""
build_model.py
==============
"""

from __future__ import annotations
from types import SimpleNamespace
from typing import Dict, Callable, Optional, Union

import jax
import jax.numpy as jnp
import numpy as np

from .data_constants import kb, amu, R_jup, R_sun, bar, G, M_jup

from .opacity_line import zero_line_opacity, compute_line_opacity
from .opacity_ck import zero_ck_opacity, compute_ck_opacity, compute_ck_opacity_perspecies
from .opacity_ray import zero_ray_opacity, compute_ray_opacity
from .opacity_cia import zero_cia_opacity, compute_cia_opacity
from .opacity_special import zero_special_opacity, compute_special_opacity
from .opacity_cloud import zero_cloud_opacity

from . import build_opacities as XS
from .build_chem import (
    prepare_chemistry_kernel,
    load_nasa9_if_needed,
    init_fastchem_grid_if_needed,
    init_element_potentials_if_needed,
    init_atmodeller_if_needed,
)
from .vert_chem import constant_vmr, constant_vmr_clr
from .vert_mu import build_compute_mu, constant_mu

from .RT_trans_1D_ck import compute_transit_depth_1d_ck
from .RT_trans_1D_ck_trans import compute_transit_depth_1d_ck_trans
from .RT_trans_1D_os import compute_transit_depth_1d_os
from .RT_em_1D_ck import compute_emission_spectrum_1d_ck
from .RT_em_1D_os import compute_emission_spectrum_1d_os
from .RT_em_schemes import get_emission_solver

from .instru_convolve import apply_response_functions_cached, get_bandpass_cache

from . import kernel_registry as KR

# Public API of this module: only the forward-model builder is exported.
__all__ = ["build_forward_model"]

# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------

def _extract_fixed_params(cfg) -> Dict[str, jnp.ndarray]:
    """Collect fixed (``dist == "delta"``) parameters and the static cloud distribution.

    Parameters
    ----------
    cfg : config object
        ``cfg.params`` is iterated; each entry must expose ``name`` and ``dist``,
        and delta entries must also carry a ``value``.  ``cfg.physics.cloud_dist``
        is read if present.

    Returns
    -------
    dict
        Maps each delta parameter name to ``jnp.asarray(value)``.  String values
        ``"true"``/``"false"`` (case-insensitive, surrounding whitespace ignored)
        become booleans; any other string is parsed with ``float``.  When
        ``physics.cloud_dist`` is set, ``"cloud_dist"`` maps to an int32 code:
        1 for monodisperse, 2 for log-normal.

    Raises
    ------
    ValueError
        If a delta parameter has no ``value``, if its string value is neither a
        boolean literal nor a float, or if ``physics.cloud_dist`` is not a
        recognised option.
    """
    fixed_params: Dict[str, jnp.ndarray] = {}
    for param in cfg.params:
        if param.dist != "delta":
            continue
        raw_value = getattr(param, "value", None)
        if raw_value is None:
            # Without this check jnp.asarray(None) fails with an opaque TypeError
            # that does not say which parameter is misconfigured.
            raise ValueError(
                f"Delta parameter '{param.name}' requires an explicit 'value'."
            )
        if isinstance(raw_value, str):
            raw_lower = raw_value.strip().lower()
            if raw_lower in ("true", "false"):
                raw_value = (raw_lower == "true")
            else:
                try:
                    raw_value = float(raw_value)
                except ValueError as err:
                    raise ValueError(
                        f"Delta parameter '{param.name}' has non-numeric value {raw_value!r}."
                    ) from err
        fixed_params[param.name] = jnp.asarray(raw_value)

    cloud_dist_raw = getattr(cfg.physics, "cloud_dist", None)
    if cloud_dist_raw is not None:
        cloud_dist_str = str(cloud_dist_raw).lower().strip()
        if cloud_dist_str in ("1", "mono", "monodisperse"):
            fixed_params["cloud_dist"] = jnp.asarray(1, dtype=jnp.int32)
        elif cloud_dist_str in ("2", "log_normal", "lognormal", "log-normal", "ln"):
            fixed_params["cloud_dist"] = jnp.asarray(2, dtype=jnp.int32)
        else:
            raise ValueError("physics.cloud_dist must be 'mono' or 'log_normal' (or 1/2).")

    return fixed_params


def _resolve_os_ck_opac(phys, key: str, fn: Callable):
    """Map a ``none | os | ck`` physics option onto its opacity kernel.

    Parameters
    ----------
    phys : object
        The ``cfg.physics`` namespace; ``key`` is looked up on it with ``getattr``.
    key : str
        Attribute name of the option (for example ``"opac_ray"``).
    fn : callable
        Kernel handed back when the option is ``os`` or ``ck``.

    Returns
    -------
    tuple
        ``(scheme, kernel)``: ``scheme`` is the lowercased option text and
        ``kernel`` is ``fn``, or ``None`` when the option is ``none``.

    Raises
    ------
    ValueError
        If the option is missing or ``None``.
    NotImplementedError
        If the option is not one of ``none``, ``os`` or ``ck``.
    """
    setting = getattr(phys, key, None)
    if setting is None:
        raise ValueError(
            f"physics.{key} must be specified explicitly (use 'None' to disable)."
        )

    scheme = str(setting).lower()
    if scheme in ("os", "ck"):
        return scheme, fn
    if scheme != "none":
        raise NotImplementedError(
            f"Unknown physics.{key}='{setting}'. Options: none | os | ck"
        )

    # Explicitly disabled: report it and hand back no kernel.
    print(f"[info] {key} is None:", setting)
    return scheme, None


def _resolve_refraction(phys) -> int:
    """Translate ``physics.refraction`` into an integer mode.

    Returns 0 (refraction off) when the option is absent or ``None``, or is one
    of ``none``/``off``/``false``/``0``.  Returns 1 (refractive cutoff) for
    ``cutoff``/``refractive_cutoff``/``refraction_cutoff``.  Matching ignores
    case and surrounding whitespace.

    Raises
    ------
    NotImplementedError
        For any other value.
    """
    setting = getattr(phys, "refraction", None)
    if setting is not None:
        mode_table = {
            "none": 0,
            "off": 0,
            "false": 0,
            "0": 0,
            "cutoff": 1,
            "refractive_cutoff": 1,
            "refraction_cutoff": 1,
        }
        token = str(setting).strip().lower()
        if token not in mode_table:
            raise NotImplementedError(f"Unknown physics.refraction='{setting}'")
        return mode_table[token]
    return 0


def _build_rt_kernel(phys, rt_scheme: str, ck: bool, ck_mix_str: str) -> Callable:
    """Return the radiative-transfer kernel for ``rt_scheme``.

    For ``transit_1d`` the choice depends on ``ck`` and, in correlated-k mode, on
    whether ``ck_mix_str`` is ``"TRANS"``.  For ``emission_1d`` the solver named
    by ``physics.em_scheme`` (default ``"eaa"``) is bound into a wrapper.  The
    correlated-k wrapper takes an extra ``opac`` argument.

    Raises
    ------
    NotImplementedError
        If ``rt_scheme`` is neither ``transit_1d`` nor ``emission_1d``.
    """
    if rt_scheme not in ("transit_1d", "emission_1d"):
        raise NotImplementedError(
            f"Unknown physics.rt_scheme='{rt_scheme}'. Options: transit_1d | emission_1d"
        )

    if rt_scheme == "transit_1d":
        if not ck:
            return compute_transit_depth_1d_os
        if ck_mix_str == "TRANS":
            return compute_transit_depth_1d_ck_trans
        return compute_transit_depth_1d_ck

    # Emission: resolve the solver once and close over it.
    solver = get_emission_solver(getattr(phys, "em_scheme", "eaa"))

    if ck:
        def _emission_ck(state, params, components, opac):
            return compute_emission_spectrum_1d_ck(
                state, params, components, opac, emission_solver=solver
            )
        return _emission_ck

    def _emission_os(state, params, components):
        return compute_emission_spectrum_1d_os(
            state, params, components, emission_solver=solver
        )
    return _emission_os


def _select_kernels(cfg) -> SimpleNamespace:
    """Choose every physics, opacity and RT kernel named in ``cfg.physics``.

    Returns a :class:`~types.SimpleNamespace` with these fields.

    Kernels (``None`` where an opacity source is disabled)
        ``Tp_kernel``, ``altitude_kernel``, ``chemistry_kernel``, ``mu_kernel``,
        ``vert_cloud_kernel``, ``line_opac_kernel``, ``ray_opac_kernel``,
        ``cia_opac_kernel``, ``cld_opac_kernel``, ``special_opac_kernel``,
        ``rt_kernel``

    Metadata consumed by validation and the forward model
        ``ck``, ``rt_scheme``, ``ck_mix_str``, ``ck_mix_code_static``,
        ``contri_func_enabled``, ``refraction_mode``, ``line_opac_str``,
        ``ray_opac_str``, ``cia_opac_str``, ``cld_opac_str``,
        ``special_opac_str``

    Raises ``ValueError`` or ``NotImplementedError`` for missing or unknown
    options (registry lookups may raise their own errors via ``KR.resolve``).
    """
    phys = cfg.physics

    def _mandatory(attr):
        # Options that must be present in the YAML (the string 'None' disables them).
        value = getattr(phys, attr, None)
        if value is None:
            raise ValueError(
                f"physics.{attr} must be specified explicitly (use 'None' to disable)."
            )
        return value

    # ---- vertical structure -------------------------------------------------
    tp_setting = getattr(phys, "vert_Tp", None) or getattr(phys, "vert_struct", None)
    Tp_kernel = KR.resolve(tp_setting, KR.VERT_TP, "physics.vert_Tp")
    altitude_kernel = KR.resolve(getattr(phys, "vert_alt", None), KR.VERT_ALT, "physics.vert_alt")
    chemistry_kernel = KR.resolve(getattr(phys, "vert_chem", None), KR.VERT_CHEM, "physics.vert_chem")
    mu_kernel = KR.resolve(getattr(phys, "vert_mu", None), KR.VERT_MU, "physics.vert_mu")
    # An absent (or empty) vert_cloud is treated as "none".
    cloud_profile_setting = getattr(phys, "vert_cloud", "none") or "none"
    vert_cloud_kernel = KR.resolve(cloud_profile_setting, KR.VERT_CLOUD, "physics.vert_cloud")

    # ---- line opacity (its scheme also switches correlated-k mode) ----------
    line_setting = _mandatory("opac_line")
    line_opac_str = str(line_setting).lower()
    line_kernels = {"none": None, "os": compute_line_opacity, "ck": compute_ck_opacity}
    if line_opac_str not in line_kernels:
        raise NotImplementedError(
            f"Unknown physics.opac_line='{line_setting}'. Options: none | os | ck"
        )
    if line_opac_str == "none":
        print("[info] Line opacity is None:", line_setting)
    line_opac_kernel = line_kernels[line_opac_str]
    ck = line_opac_str == "ck"

    # ---- continuum opacities ------------------------------------------------
    ray_opac_str, ray_opac_kernel = _resolve_os_ck_opac(phys, "opac_ray", compute_ray_opacity)
    cia_opac_str, cia_opac_kernel = _resolve_os_ck_opac(phys, "opac_cia", compute_cia_opacity)

    # ---- cloud opacity ------------------------------------------------------
    cloud_setting = _mandatory("opac_cloud")
    cld_opac_str = str(cloud_setting).lower()
    if cld_opac_str == "none":
        print("[info] Cloud opacity is None:", cloud_setting)
    cld_opac_kernel = KR.resolve(cloud_setting, KR.OPAC_CLOUD, "physics.opac_cloud")

    # ---- special (H-) opacity: on unless explicitly switched off ------------
    special_opac_str = str(getattr(phys, "opac_special", "on")).lower()
    special_off = special_opac_str in ("none", "off", "false", "0")
    special_opac_kernel = None if special_off else compute_special_opacity

    # ---- radiative transfer -------------------------------------------------
    rt_setting = getattr(phys, "rt_scheme", None)
    if rt_setting is None or str(rt_setting).lower() == "none":
        raise ValueError("physics.rt_scheme must be specified explicitly.")
    rt_scheme = str(rt_setting).lower()

    refraction_mode = _resolve_refraction(phys)

    ck_mix_str = str(getattr(cfg.opac, "ck_mix", "RORR")).upper()
    contri_func_enabled = bool(getattr(phys, "contri_func", False))

    # Static mixing code for the CK path: PRAS -> 2, TRANS -> 3, anything else -> 1.
    ck_mix_code_static = {"PRAS": 2, "TRANS": 3}.get(ck_mix_str, 1) if ck else None

    rt_kernel = _build_rt_kernel(phys, rt_scheme, ck, ck_mix_str)

    return SimpleNamespace(
        Tp_kernel=Tp_kernel,
        altitude_kernel=altitude_kernel,
        chemistry_kernel=chemistry_kernel,
        mu_kernel=mu_kernel,
        vert_cloud_kernel=vert_cloud_kernel,
        line_opac_kernel=line_opac_kernel,
        ray_opac_kernel=ray_opac_kernel,
        cia_opac_kernel=cia_opac_kernel,
        cld_opac_kernel=cld_opac_kernel,
        special_opac_kernel=special_opac_kernel,
        rt_kernel=rt_kernel,
        ck=ck,
        rt_scheme=rt_scheme,
        ck_mix_str=ck_mix_str,
        ck_mix_code_static=ck_mix_code_static,
        contri_func_enabled=contri_func_enabled,
        refraction_mode=refraction_mode,
        line_opac_str=line_opac_str,
        ray_opac_str=ray_opac_str,
        cia_opac_str=cia_opac_str,
        cld_opac_str=cld_opac_str,
        special_opac_str=special_opac_str,
    )


def _build_opac_cache() -> Dict[str, jnp.ndarray]:
    """Gather every loaded opacity table into one flat dict for the kernels.

    Each opacity family contributes entries only when its ``XS.has_*_data()``
    check succeeds.  The H- bound-free and free-free tables are optional even
    when special data is loaded: a getter that raises ``RuntimeError`` is
    skipped.
    """
    cache: Dict[str, jnp.ndarray] = {}

    def _take(*entries):
        # Store each (cache key, registry getter) result, in the order given.
        for cache_key, getter in entries:
            cache[cache_key] = getter()

    if XS.has_ck_data():
        _take(
            ("ck_sigma_cube", XS.ck_sigma_cube),
            ("ck_log10_pressure_grid", XS.ck_log10_pressure_grid),
            ("ck_log10_temperature_grids", XS.ck_log10_temperature_grids),
            ("g_points", XS.ck_g_points_1d),
            ("g_weights", XS.ck_g_weights_1d),
        )
    if XS.has_line_data():
        _take(
            ("line_sigma_cube", XS.line_sigma_cube),
            ("line_log10_pressure_grid", XS.line_log10_pressure_grid),
            ("line_log10_temperature_grids", XS.line_log10_temperature_grids),
        )
    if XS.has_cia_data():
        _take(
            ("cia_master_wavelength", XS.cia_master_wavelength),
            ("cia_pair_species_i", XS.cia_pair_species_i),
            ("cia_pair_species_j", XS.cia_pair_species_j),
            ("cia_retained_sigma_cube", XS.cia_retained_sigma_cube),
            ("cia_retained_log10_temperature_grids", XS.cia_retained_log10_temperature_grids),
            ("cia_retained_temperature_grids", XS.cia_retained_temperature_grids),
        )
    if XS.has_ray_data():
        _take(
            ("ray_master_wavelength", XS.ray_master_wavelength),
            ("ray_sigma_linear_table", XS.ray_sigma_linear_table),
            ("ray_refractivity_coeff_table", XS.ray_refractivity_coeff_table),
        )
    if XS.has_cloud_nk_data():
        _take(
            ("cloud_nk_n", XS.cloud_nk_n),
            ("cloud_nk_k", XS.cloud_nk_k),
        )
    if XS.has_special_data():
        _take(
            ("hminus_master_wavelength", XS.special_master_wavelength),
            ("hminus_temperature_grid", XS.hminus_temperature_grid),
            ("hminus_log10_temperature_grid", XS.hminus_log10_temperature_grid),
        )
        # Which H- tables exist depends on the cfg.opac.special flags; a missing
        # table's getter raises RuntimeError and that entry is simply left out.
        for cache_key, getter in (
            ("hminus_bf_log10_sigma", XS.hminus_bf_log10_sigma_table),
            ("hminus_ff_log10_sigma", XS.hminus_ff_log10_sigma_table),
        ):
            try:
                cache[cache_key] = getter()
            except RuntimeError:
                pass
    return cache


def _require_cache_keys(opac_cache: dict, keys: tuple, label: str) -> None:
    """Raise ``RuntimeError`` naming every entry of ``keys`` absent from ``opac_cache``.

    ``label`` identifies the opacity source in the message; missing keys are
    listed in the order they appear in ``keys``.
    """
    absent = list(filter(lambda name: name not in opac_cache, keys))
    if not absent:
        return
    raise RuntimeError(f"Missing {label} cache entries: {absent}")


def _validate_config(
    cfg,
    k: SimpleNamespace,
    opac_cache: Dict[str, jnp.ndarray],
) -> None:
    """Cross-check the configuration, the selected kernels and the loaded opacity data.

    Every inconsistency raises at once (``RuntimeError``, ``ValueError`` or
    ``NotImplementedError``), so misconfiguration is reported when the model is
    built rather than silently producing wrong results.

    Parameters
    ----------
    cfg : config object
        Only the names in ``cfg.params`` are read here.
    k : SimpleNamespace
        Kernel selection and metadata as produced by ``_select_kernels``.
    opac_cache : dict
        Runtime opacity cache as produced by ``_build_opac_cache``.
    """
    ck_keys = ("ck_sigma_cube", "ck_log10_pressure_grid", "ck_log10_temperature_grids",
               "g_points", "g_weights")
    os_keys = ("line_sigma_cube", "line_log10_pressure_grid", "line_log10_temperature_grids")
    cia_keys = ("cia_master_wavelength", "cia_pair_species_i", "cia_pair_species_j",
                "cia_retained_sigma_cube", "cia_retained_log10_temperature_grids",
                "cia_retained_temperature_grids")
    ray_keys = ("ray_master_wavelength", "ray_sigma_linear_table")
    hminus_keys = ("hminus_master_wavelength", "hminus_temperature_grid",
                   "hminus_log10_temperature_grid")
    # Cloud schemes that evaluate Mie-type optics from refractive-index (n, k) tables.
    mie_cloud_schemes = ("madt_rayleigh", "madt-rayleigh", "mie_madt",
                         "lxmie", "mie_full", "full_mie")

    # ---- registries must contain the requested line / correlated-k tables ----
    if k.line_opac_str == "os" and not XS.has_line_data():
        raise RuntimeError(
            "Line opacity requested but registry is empty. "
            "Check cfg.opac.line and ensure build_opacities() loaded line tables."
        )
    if k.ck and not XS.has_ck_data():
        raise RuntimeError(
            "CK opacity requested but registry is empty. "
            "Check cfg.opac and ensure build_opacities() loaded ck tables."
        )

    # ---- refraction cutoff prerequisites -------------------------------------
    if k.refraction_mode == 1:
        if k.rt_scheme != "transit_1d":
            raise NotImplementedError(
                "physics.refraction is only supported for rt_scheme: transit_1d."
            )
        declared_names = {p.name for p in cfg.params}
        if "a_sm" not in declared_names:
            raise ValueError(
                "physics.refraction: cutoff requires a delta parameter "
                "'a_sm' (semi-major axis in AU)."
            )
        if not XS.has_ray_data():
            raise RuntimeError(
                "physics.refraction: cutoff requires Rayleigh registry data (cfg.opac.ray)."
            )

    # ---- TRANS mixing is transit-only and has no contribution functions ------
    if k.ck and k.ck_mix_str == "TRANS":
        if k.rt_scheme != "transit_1d":
            raise NotImplementedError(
                "ck_mix: TRANS is only supported for rt_scheme: transit_1d."
            )
        if k.contri_func_enabled:
            raise ValueError(
                "physics.contri_func=True is not supported with opac.ck_mix=TRANS. "
                "Use ck_mix=RORR/PRAS or disable contribution functions."
            )

    # ---- cache entries that each enabled kernel reads at runtime -------------
    if k.ck and k.line_opac_str == "ck":
        _require_cache_keys(opac_cache, ck_keys, "correlated-k")
    if not k.ck and k.line_opac_str == "os":
        _require_cache_keys(opac_cache, os_keys, "opacity sampling")
    if k.cia_opac_kernel is not None:
        _require_cache_keys(opac_cache, cia_keys, "CIA")
    if k.ray_opac_kernel is not None:
        _require_cache_keys(opac_cache, ray_keys, "Rayleigh")
    if k.special_opac_kernel is not None and XS.has_special_data():
        _require_cache_keys(opac_cache, hminus_keys, "special-opacity")
    if k.cld_opac_str in mie_cloud_schemes:
        _require_cache_keys(opac_cache, ("cloud_nk_n", "cloud_nk_k"), "cloud n,k")


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def build_forward_model(
    cfg,
    obs: Dict,
    stellar_flux: Optional[np.ndarray] = None,
    return_highres: bool = False,
) -> Callable[[Dict[str, jnp.ndarray]], Union[jnp.ndarray, Dict[str, jnp.ndarray]]]:
    """Build a JIT-compiled forward model for atmospheric retrieval.

    This function assembles physics kernels for:

    - vertical structure (temperature, chemistry, mean molecular weight, altitude, cloud profile);
    - opacity sources (line, Rayleigh, CIA, special/H-, clouds);
    - radiative transfer;
    - instrumental convolution.

    When ``physics.rt_scheme`` is ``transit_1_5d``, construction is delegated to
    ``build_model_1_5D.build_forward_model_1_5d``.

    Parameters
    ----------
    cfg : config object
        Configuration containing physics settings (``cfg.physics``), opacity
        configuration (``cfg.opac``) and retrieval parameters (``cfg.params``).
        It must specify schemes for:

        - vertical structure (``vert_Tp``, ``vert_alt``, ``vert_chem``, ``vert_mu``);
        - opacity sources (``opac_line``, ``opac_ray``, ``opac_cia``,
          ``opac_cloud``; ``opac_special`` defaults to on);
        - radiative transfer (``rt_scheme``).
    obs : dict
        Observational data; ``'wl'`` and ``'dwl'`` (wavelengths and bin widths in
        microns) must be present and numeric.  They are consumed by the
        response-function caches rather than by this function.
    stellar_flux : `~numpy.ndarray`, optional
        Stellar flux on the high-resolution wavelength grid.  For
        ``rt_scheme='emission_1d'`` in planet mode (``physics.emission_mode`` not
        a brown-dwarf alias), either this array or a parameter named ``F_star``
        must be provided.
    return_highres : bool, optional
        If True, the forward model returns a dict ``{'hires': D_hires,
        'binned': D_bin}``.  When ``physics.contri_func`` is enabled, the dict
        also holds ``'contrib_func'`` and ``'p_lay'``.  If False (default), only
        the binned spectrum is returned.

    Returns
    -------
    forward_model : callable
        ``forward_model(params) -> array | dict``, which calls a ``jax.jit``-compiled
        implementation.  ``params`` holds the free retrieval parameters; fixed
        (delta) parameters from ``cfg.params`` are merged in automatically.

    Raises
    ------
    ValueError, RuntimeError, NotImplementedError
        On inconsistent or incomplete configuration (see ``_select_kernels`` and
        ``_validate_config``).  Planet emission without ``stellar_flux`` or
        ``F_star`` raises ``ValueError``.
    """
    rt_scheme_raw = getattr(getattr(cfg, "physics", None), "rt_scheme", None)
    if str(rt_scheme_raw).lower() == "transit_1_5d":
        from .build_model_1_5D import build_forward_model_1_5d
        return build_forward_model_1_5d(
            cfg,
            obs,
            stellar_flux=stellar_flux,
            return_highres=return_highres,
        )

    fixed_params = _extract_fixed_params(cfg)

    nlay = int(getattr(cfg.physics, "nlay", 99))
    nlev = nlay + 1

    # Observational wavelengths/widths are consumed by response-function caches,
    # not directly by this function body; converting them here makes a missing
    # or non-numeric entry fail at build time.
    _ = np.asarray(obs["wl"], dtype=float)
    _ = np.asarray(obs["dwl"], dtype=float)

    # Select all physics and opacity kernels
    k = _select_kernels(cfg)

    # Assemble runtime opacity cache from loaded registries
    opac_cache = _build_opac_cache()

    # Validate consistency of the full configuration
    _validate_config(cfg, k, opac_cache)

    # A missing or null emission_mode means "planet".  The check happens before
    # str(), which would otherwise turn None into the string "none".
    emission_mode_raw = getattr(cfg.physics, "emission_mode", None)
    if emission_mode_raw is None:
        emission_mode = "planet"
    else:
        emission_mode = str(emission_mode_raw).lower().replace(" ", "_")
    is_brown_dwarf = emission_mode in ("brown_dwarf", "browndwarf", "bd")

    # For planet emission, require stellar normalization information up front.
    # This avoids silent fallbacks inside JIT code paths.
    if k.rt_scheme == "emission_1d" and (not is_brown_dwarf) and (stellar_flux is None):
        param_names = {p.name for p in cfg.params}
        if "F_star" not in param_names:
            raise ValueError(
                "Planet emission mode requires either stellar_flux input or parameter 'F_star'."
            )

    # High-resolution master grid (must match cut_grid used in bandpass loader)
    wl_hi_array = np.asarray(XS.master_wavelength_cut(), dtype=float)
    wl_hi = jnp.asarray(wl_hi_array)

    has_stellar_flux_arr = jnp.asarray(1 if stellar_flux is not None else 0, dtype=jnp.int32)
    if stellar_flux is not None:
        stellar_flux_arr = jnp.asarray(stellar_flux, dtype=jnp.float64)
    else:
        stellar_flux_arr = jnp.zeros_like(wl_hi, dtype=jnp.float64)

    bandpass_cache = get_bandpass_cache()

    # Ensure chemistry backends that require pre-built caches are initialized
    # when build_forward_model is used directly by analysis scripts.
    load_nasa9_if_needed(cfg, None)
    init_fastchem_grid_if_needed(cfg, None)
    init_element_potentials_if_needed(cfg, None)
    init_atmodeller_if_needed(cfg, None)

    chemistry_kernel, trace_species = prepare_chemistry_kernel(
        cfg,
        k.chemistry_kernel,
        {
            'line_opac': k.line_opac_str,
            'ray_opac': k.ray_opac_str,
            'cia_opac': k.cia_opac_str,
            'special_opac': k.special_opac_str,
        },
    )

    # Default to the registry-selected mean-molecular-weight kernel.  For
    # constant-VMR chemistry it is overridden below by a kernel built from the
    # packed species list.  The default is assigned here, before the override,
    # so the override is not clobbered later.
    mu_kernel = k.mu_kernel
    if k.chemistry_kernel in (constant_vmr, constant_vmr_clr):
        cfg_param_names = {str(getattr(p, "name", "")) for p in getattr(cfg, "params", [])}
        include_atomic_h = "log_10_H_over_H2" in cfg_param_names
        # De-duplicated, order-preserving species list for the mu computation.
        packed_mu_species = tuple(
            dict.fromkeys((*trace_species, "H2", "He", *(("H",) if include_atomic_h else ())))
        )
        compute_mu_fast = build_compute_mu(packed_mu_species)
        mu_mode = str(getattr(cfg.physics, "vert_mu", "auto")).lower()
        if mu_mode == "auto":
            def mu_kernel(params, vmr_lay, nlay, _compute_mu_fast=compute_mu_fast):
                if "mu" in params:
                    return constant_mu(params, nlay)
                if "__mu_lay__" in vmr_lay:
                    return vmr_lay["__mu_lay__"]
                return _compute_mu_fast(vmr_lay)
        elif mu_mode in ("dynamic", "variable", "vmr"):
            def mu_kernel(params, vmr_lay, nlay, _compute_mu_fast=compute_mu_fast):
                if "__mu_lay__" in vmr_lay:
                    return vmr_lay["__mu_lay__"]
                return _compute_mu_fast(vmr_lay)

    # Capture kernel selections into local names for the JIT closure
    # (mu_kernel and chemistry_kernel were bound above).
    Tp_kernel = k.Tp_kernel
    altitude_kernel = k.altitude_kernel
    vert_cloud_kernel = k.vert_cloud_kernel
    line_opac_kernel = k.line_opac_kernel
    ray_opac_kernel = k.ray_opac_kernel
    cia_opac_kernel = k.cia_opac_kernel
    cld_opac_kernel = k.cld_opac_kernel
    special_opac_kernel = k.special_opac_kernel
    rt_kernel = k.rt_kernel
    ck = k.ck
    rt_scheme = k.rt_scheme
    ck_mix_code_static = k.ck_mix_code_static
    contri_func_enabled = k.contri_func_enabled
    refraction_mode = k.refraction_mode

    @jax.jit
    def _forward_model_impl(
        params: Dict[str, jnp.ndarray],
        wl_runtime: jnp.ndarray,
        stellar_flux_runtime: jnp.ndarray,
        has_stellar_flux_runtime: jnp.ndarray,
        opac_cache_runtime: Dict[str, jnp.ndarray],
        bandpass_cache_runtime: Dict[str, jnp.ndarray],
    ) -> Union[jnp.ndarray, Dict[str, jnp.ndarray]]:
        # Merge fixed (delta) parameters with varying parameters; sampled values
        # win on key clashes because they are unpacked last.
        full_params = {**fixed_params, **params}

        wl = wl_runtime

        # Dimension constants
        nwl = wl.shape[0]

        # Reference radius: derived from mass and surface gravity when M_p is
        # given, otherwise taken directly from R_p.
        if "M_p" in full_params:
            M_p = full_params["M_p"] * M_jup  # Convert to g
            g = (10.0 ** full_params["log_10_g"])  # cm/s^2
            R0 = jnp.sqrt((G * M_p) / g)  # Radius in cm
        else:
            R0 = full_params["R_p"] * R_jup
        R_s = full_params["R_s"] * R_sun

        # Atmospheric pressure grid (log-spaced levels from p_bot to p_top)
        p_bot = full_params["p_bot"] * bar
        p_top = full_params["p_top"] * bar
        p_lev = jnp.logspace(jnp.log10(p_bot), jnp.log10(p_top), nlev)

        # Layer pressure: logarithmic mean of the two bounding level pressures
        p_lay = (p_lev[1:] - p_lev[:-1]) / jnp.log(p_lev[1:] / p_lev[:-1])

        # Vertical atmospheric T-p layer structure
        T_lev, T_lay = Tp_kernel(p_lev, full_params)

        # Get the vertical chemical structure (VMRs at each layer)
        vmr_lay = chemistry_kernel(p_lay, T_lay, full_params, nlay)

        # Mean molecular weight calculation
        mu_lay = mu_kernel(full_params, vmr_lay, nlay)

        # Vertical altitude calculation
        z_lev, z_lay, dz = altitude_kernel(p_lev, T_lay, mu_lay, full_params)

        # Atmospheric density and number density
        rho_lay = (mu_lay * amu * p_lay) / (kb * T_lay)
        nd_lay = p_lay / (kb * T_lay)

        # Cloud vertical profile (mass mixing ratio)
        q_c_lay = vert_cloud_kernel(p_lay, T_lay, mu_lay, rho_lay, nd_lay, full_params)

        # Opacity cache for kernels (separate from atmospheric state)
        opac = opac_cache_runtime

        state = {
            "nwl": nwl,
            "nlay": nlay,
            "wl": wl,
            "is_brown_dwarf": is_brown_dwarf,
            "R0": R0,
            "R_s": R_s,
            "p_lev": p_lev,
            "T_lev": T_lev,
            "z_lev": z_lev,
            "z_lay": z_lay,
            "dz": dz,
            "mu_lay": mu_lay,
            "T_lay": T_lay,
            "p_lay": p_lay,
            "rho_lay": rho_lay,
            "nd_lay": nd_lay,
            "q_c_lay": q_c_lay,
            "vmr_lay": vmr_lay,
            "contri_func": contri_func_enabled,
            "refraction_mode": refraction_mode,
        }
        if "cloud_nk_n" in opac:
            state["cloud_nk_n"] = opac["cloud_nk_n"]
            state["cloud_nk_k"] = opac["cloud_nk_k"]
        state["stellar_flux"] = stellar_flux_runtime
        state["has_stellar_flux"] = has_stellar_flux_runtime
        if ck_mix_code_static is not None:
            state["ck_mix"] = ck_mix_code_static

        # Start every opacity component at zero; enabled kernels overwrite below.
        if ck:
            line_zero = zero_ck_opacity(state, opac, full_params)
        else:
            line_zero = zero_line_opacity(state, full_params)
        k_cld_zero, cld_ssa_zero, cld_g_zero = zero_cloud_opacity(state, full_params)
        opacity_components = {
            "line": line_zero,
            "rayleigh": zero_ray_opacity(state, full_params),
            "cia": zero_cia_opacity(state, full_params),
            "special": zero_special_opacity(state, full_params),
            "cloud": k_cld_zero,
            "cloud_ssa": cld_ssa_zero,
            "cloud_g": cld_g_zero,
        }

        if line_opac_kernel is not None:
            # For TRANS method, compute per-species opacities (mixing happens in RT)
            if ck and ck_mix_code_static == 3:  # TRANS
                sigma_ps, vmr_ps = compute_ck_opacity_perspecies(state, opac, full_params)
                opacity_components["line_perspecies"] = sigma_ps
                opacity_components["vmr_perspecies"] = vmr_ps
            else:
                opacity_components["line"] = line_opac_kernel(state, opac, full_params)
        if ray_opac_kernel is not None:
            opacity_components["rayleigh"] = ray_opac_kernel(state, opac, full_params)
        if cia_opac_kernel is not None:
            opacity_components["cia"] = cia_opac_kernel(state, opac, full_params)
        if special_opac_kernel is not None:
            opacity_components["special"] = special_opac_kernel(state, opac, full_params)
        if cld_opac_kernel is not None:
            k_cld_ext, cld_ssa, cld_g = cld_opac_kernel(state, full_params)
            opacity_components["cloud"] = k_cld_ext
            opacity_components["cloud_ssa"] = cld_ssa
            opacity_components["cloud_g"] = cld_g

        # Radiative transfer.  RT kernels always return (spectrum, contrib_func).
        # Transit kernels and CK emission take the opac cache (OS transit may use
        # it for refraction); OS emission does not.
        if rt_scheme == "transit_1d" or ck:
            D_hires, contrib_func = rt_kernel(state, full_params, opacity_components, opac)
        else:
            D_hires, contrib_func = rt_kernel(state, full_params, opacity_components)

        # Instrumental convolution → binned spectrum
        D_bin = apply_response_functions_cached(D_hires, bandpass_cache_runtime)

        if return_highres:
            result_dict = {"hires": D_hires, "binned": D_bin}
            if state["contri_func"]:
                result_dict["contrib_func"] = contrib_func
                result_dict["p_lay"] = p_lay
            return result_dict

        return D_bin

    def forward_model(
        params: Dict[str, jnp.ndarray],
    ) -> Union[jnp.ndarray, Dict[str, jnp.ndarray]]:
        """Evaluate the compiled model for ``params`` with the caches captured at build time."""
        return _forward_model_impl(
            params,
            wl_hi,
            stellar_flux_arr,
            has_stellar_flux_arr,
            opac_cache,
            bandpass_cache,
        )

    return forward_model