Source code for exo_skryer.RT_em_1D_os

"""
RT_em_1D_os.py
==============
"""

from __future__ import annotations

from typing import Dict, Mapping, Tuple

import jax.numpy as jnp

from .data_constants import kb, h, c_light, pc
from .RT_em_schemes import solve_alpha_eaa

__all__ = ["compute_emission_spectrum_1d_os"]


def _sum_opacity_components_os(
    state: Dict[str, jnp.ndarray],
    opacity_components: Mapping[str, jnp.ndarray],
) -> jnp.ndarray:
    """Add the recognised opacity contributions into a single array.

    Only the keys ``line``, ``rayleigh``, ``cia``, ``special`` and ``cloud``
    take part in the sum; every other entry of ``opacity_components`` (for
    example ``cloud_ssa``) is ignored.

    Parameters
    ----------
    state
        Must provide ``rho_lay`` and ``wl``.  The sizes of their leading axes
        define the shape of the all-zero fallback.
    opacity_components
        Mapping from component name to opacity array.  The recognised
        components are stacked together, so they need identical shapes.

    Returns
    -------
    jnp.ndarray
        Element-wise sum of the recognised components, or zeros of shape
        ``(rho_lay.shape[0], wl.shape[0])`` when there is nothing to sum.
    """
    summed_names = ("line", "rayleigh", "cia", "special", "cloud")
    fallback_shape = (state["rho_lay"].shape[0], state["wl"].shape[0])

    # The first recognised entry serves as the shape/dtype template for the
    # zero placeholders that stand in for absent components.
    template = None
    if opacity_components:
        for name in summed_names:
            if name in opacity_components:
                template = opacity_components.get(name)
                break
    if template is None:
        return jnp.zeros(fallback_shape)

    placeholder = jnp.zeros_like(template)
    layers = [opacity_components.get(name, placeholder) for name in summed_names]
    return jnp.sum(jnp.stack(layers, axis=0), axis=0)


def _planck_lambda(wavelength_cm: jnp.ndarray, temperature: jnp.ndarray) -> jnp.ndarray:
    """Evaluate the Planck function per unit wavelength.

    Computes ``2 h c^2 / lambda^5 / (exp(h c / (lambda k_B T)) - 1)`` with
    three numerical guards: the temperature is floored at 1.0, the exponent is
    capped at 80.0 before ``expm1`` is applied, and the resulting denominator
    is floored at 1e-300.

    Parameters
    ----------
    wavelength_cm
        Wavelength array; must broadcast against ``temperature``.
        NOTE(review): the name implies centimetres, which assumes ``h``,
        ``c_light`` and ``kb`` are CGS values -- confirm in data_constants.
    temperature
        Temperature array; values below 1.0 are treated as 1.0.

    Returns
    -------
    jnp.ndarray
        Planck function values with the broadcast shape of the two inputs.
    """
    floored_temperature = jnp.maximum(temperature, 1.0)
    x = (h * c_light) / (wavelength_cm * kb * floored_temperature)
    # Cap the argument so the exponential stays finite for very large x.
    denominator = jnp.maximum(jnp.expm1(jnp.clip(x, None, 80.0)), 1e-300)
    numerator = 2.0 * h * c_light**2 / (wavelength_cm**5)
    return numerator / denominator


def _layer_optical_depth_os(k_tot: jnp.ndarray, rho: jnp.ndarray, dz: jnp.ndarray) -> jnp.ndarray:
    """Return ``k_tot * rho * dz`` per layer and wavelength.

    Parameters
    ----------
    k_tot
        Two-dimensional opacity array whose leading axis matches ``rho`` and
        ``dz``.
    rho, dz
        One-dimensional arrays; each gains a trailing axis so that it
        broadcasts across the second axis of ``k_tot``.

    Returns
    -------
    jnp.ndarray
        Array with the shape of ``k_tot``.
    """
    rho_column = rho[:, None]
    dz_column = dz[:, None]
    return k_tot * rho_column * dz_column


def _compute_scattering_properties(
    opacity_components: Mapping[str, jnp.ndarray],
    state: Dict[str, jnp.ndarray],
    k_tot: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Derive the single-scattering albedo and asymmetry parameter.

    The scattering opacity is ``rayleigh + cloud_ssa * cloud``.  Dividing it
    by ``k_tot`` (floored at 1e-30) and clipping to ``[0, 0.99]`` gives the
    albedo.  The asymmetry parameter is the ``cloud_g`` entry passed through
    unchanged.

    Parameters
    ----------
    opacity_components
        Mapping that may contain ``rayleigh``, ``cloud``, ``cloud_ssa`` and
        ``cloud_g``.  A missing (or ``None``) entry is replaced by zeros with
        the shape of ``k_tot``.
    state
        Not used by this function; kept for signature compatibility.
    k_tot
        Two-dimensional total opacity array.

    Returns
    -------
    tuple of jnp.ndarray
        ``(ssa, g)``.
    """
    n_layers, n_wavelengths = k_tot.shape

    fetched = {}
    for key in ("rayleigh", "cloud", "cloud_ssa", "cloud_g"):
        value = opacity_components.get(key)
        fetched[key] = jnp.zeros((n_layers, n_wavelengths)) if value is None else value

    cloud_scattering = fetched["cloud_ssa"] * fetched["cloud"]
    total_scattering = fetched["rayleigh"] + cloud_scattering

    # Floor the denominator so fully transparent cells do not divide by zero.
    albedo = jnp.clip(total_scattering / jnp.maximum(k_tot, 1.0e-30), 0.0, 0.99)
    return albedo, fetched["cloud_g"]


def _scale_flux_ratio(
    flux: jnp.ndarray,
    state: Dict[str, jnp.ndarray],
    params: Dict[str, jnp.ndarray],
) -> jnp.ndarray:
    """Scale ``flux`` by ``R0**2 / (F_star * R_s**2)``.

    ``F_star`` is taken from ``state["stellar_flux"]`` when
    ``state["has_stellar_flux"]`` is non-zero, and from ``params["F_star"]``
    otherwise.  Both sources default to ones when absent, and the flag
    defaults to 0.

    Parameters
    ----------
    flux
        Flux array; its dtype is used for every intermediate quantity.
    state
        Must contain ``R0`` and ``R_s``; may contain ``stellar_flux`` and
        ``has_stellar_flux``.
        NOTE(review): ``R0`` / ``R_s`` are presumably the planetary and
        stellar radii in the same units -- confirm against the caller.
    params
        May contain ``F_star``.

    Returns
    -------
    jnp.ndarray
        ``flux`` multiplied by the scale factor.
    """
    work_dtype = flux.dtype

    def _as_work(value):
        # Cast to the dtype of the incoming flux.
        return jnp.asarray(value, dtype=work_dtype)

    flux_from_state = _as_work(state.get("stellar_flux", jnp.ones_like(flux)))
    state_flux_enabled = jnp.asarray(state.get("has_stellar_flux", 0), dtype=jnp.int32) != 0
    flux_from_params = _as_work(params.get("F_star", jnp.ones_like(flux_from_state)))
    F_star = jnp.where(state_flux_enabled, flux_from_state, flux_from_params)

    r0 = _as_work(state["R0"])
    r_s = _as_work(state["R_s"])
    return flux * ((r0**2) / (F_star * (r_s**2)))


def compute_emission_spectrum_1d_os(
    state: Dict[str, jnp.ndarray],
    params: Dict[str, jnp.ndarray],
    opacity_components: Mapping[str, jnp.ndarray],
    emission_solver=solve_alpha_eaa,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Compute a 1D emission spectrum and its normalised contribution function.

    The routine builds Planck source terms at each temperature level, derives
    layer optical depths and scattering properties from the summed opacity,
    and passes them to ``emission_solver``.  When both ``params["f_cloud"]``
    and ``opacity_components["cloud"]`` are present, a second cloud-free solve
    is performed and the two results are blended with weight ``f_cloud``
    (clipped to ``[0, 1]``).

    Parameters
    ----------
    state
        Reads ``wl``, ``T_lev``, ``rho_lay`` and ``dz`` (all cast to the dtype
        of ``wl``), the optional flags ``contri_func`` and ``is_brown_dwarf``,
        and ``R0``.  The non-brown-dwarf branch additionally reads ``R_s``
        and, optionally, ``stellar_flux`` / ``has_stellar_flux``.
        NOTE(review): ``wl`` is multiplied by 1e-4 to obtain ``wl_cm``, so it
        is presumably given in microns -- confirm against the caller.
    params
        Optional ``T_int`` (internal temperature for the lower boundary source
        term, zero source when absent), ``f_cloud``, ``s_dilute`` (default
        1.0), ``log_10_f_p`` (default 0.0) and ``F_star``.  ``D`` is required
        when ``state["is_brown_dwarf"]`` is truthy.
    opacity_components
        Mapping of opacity arrays by component name.
    emission_solver
        Callable invoked as ``emission_solver(be_levels, dtau, ssa, g,
        be_internal, return_layer_contrib=...)``.  It must return three
        values: the first is indexed with ``[0]`` to obtain the emergent
        flux and the third is used as the per-layer contribution.

    Returns
    -------
    tuple of jnp.ndarray
        ``(final_spectrum, contrib_func_norm)``.  When ``contri_func`` is
        falsy the second element is an all-zero array of shape
        ``(dz.shape[0], final_spectrum.shape[0])``.
    """
    want_contrib = state.get("contri_func", False)
    work_dtype = state["wl"].dtype

    wl_cm = jnp.asarray(state["wl"], dtype=work_dtype) * 1.0e-4
    T_lev = jnp.asarray(state["T_lev"], dtype=work_dtype)
    rho_lay = jnp.asarray(state["rho_lay"], dtype=work_dtype)
    dz = jnp.asarray(state["dz"], dtype=work_dtype)

    # Source function at every level, plus the optional internal source.
    wl_row = wl_cm[None, :]
    be_levels = _planck_lambda(wl_row, T_lev[:, None])
    if "T_int" in params:
        T_int = jnp.asarray(params["T_int"], dtype=work_dtype)
        be_internal = _planck_lambda(wl_row, T_int[None, None])[0]
    else:
        be_internal = jnp.zeros_like(be_levels[-1])

    def _solve_column(
        components: Mapping[str, jnp.ndarray],
        k_total: jnp.ndarray,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # Run the solver for one set of opacity components.
        ssa, g = _compute_scattering_properties(components, state, k_total)
        dtau = _layer_optical_depth_os(k_total, rho_lay, dz)
        upward, _, contrib = emission_solver(
            be_levels, dtau, ssa, g, be_internal, return_layer_contrib=want_contrib
        )
        return upward, contrib

    k_cloudy = _sum_opacity_components_os(state, opacity_components)
    lw_up_cloudy, contrib_cloudy = _solve_column(opacity_components, k_cloudy)

    patchy = "f_cloud" in params and "cloud" in opacity_components
    if not patchy:
        lw_up = lw_up_cloudy
        layer_contrib_flux = contrib_cloudy
    else:
        f_cloud = jnp.clip(params["f_cloud"], 0.0, 1.0)

        # Cloud-free column: remove the cloud opacity and zero every cloud
        # property so that it contributes no scattering either.
        cloud_opacity = opacity_components["cloud"]
        k_clear = k_cloudy - cloud_opacity
        blank = jnp.zeros_like(cloud_opacity)
        cloud_free = dict(opacity_components)
        cloud_free.update(cloud=blank, cloud_ssa=blank, cloud_g=blank)
        lw_up_clear, contrib_clear = _solve_column(cloud_free, k_clear)

        lw_up = f_cloud * lw_up_cloudy + (1.0 - f_cloud) * lw_up_clear
        layer_contrib_flux = f_cloud * contrib_cloudy + (1.0 - f_cloud) * contrib_clear

    s_dilute = jnp.asarray(params.get("s_dilute", 1.0))
    top_flux = s_dilute * lw_up[0]

    if state.get("is_brown_dwarf", False):
        # Scale by (R0 / distance)**2, with the distance given by D * pc.
        R0 = jnp.asarray(state["R0"], dtype=top_flux.dtype)
        D = params["D"]
        f_p = 10.0 ** jnp.asarray(params.get("log_10_f_p", 0.0), dtype=top_flux.dtype)
        distance = D * pc
        final_spectrum = top_flux * f_p * (R0 / distance) ** 2
    else:
        final_spectrum = _scale_flux_ratio(top_flux, state, params)

    if want_contrib:
        # Drop negative contributions, then normalise over the leading axis.
        layer_contrib = jnp.clip(layer_contrib_flux, 0.0)
        column_total = jnp.maximum(layer_contrib.sum(axis=0, keepdims=True), 1e-30)
        contrib_func_norm = layer_contrib / column_total
    else:
        contrib_func_norm = jnp.zeros((state["dz"].shape[0], final_spectrum.shape[0]))

    return final_spectrum, contrib_func_norm