# Source code for exo_skryer.RT_em_1D_ck

"""
RT_em_1D_ck.py
==============
"""

from __future__ import annotations

from typing import Dict, Mapping, Tuple

import jax.numpy as jnp
from jax import lax

from .data_constants import kb, h, c_light, pc
from .RT_em_schemes import solve_alpha_eaa

# Only the c-k emission spectrum driver is exported by ``from ... import *``;
# the underscore-prefixed helpers in this module are internal.
__all__ = ["compute_emission_spectrum_1d_ck"]


def _get_ck_weights(opac: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """Return the correlated-k quadrature weights stored in ``opac``.

    Raises
    ------
    RuntimeError
        If ``opac`` has no ``"g_weights"`` entry, or the entry is ``None``.
    """
    weights = opac.get("g_weights")
    if weights is not None:
        return weights
    raise RuntimeError("Missing opac['g_weights'] for c-k integration.")


def _sum_opacity_components_ck(
    state: Dict[str, jnp.ndarray],
    opacity_components: Mapping[str, jnp.ndarray],
    opac: Dict[str, jnp.ndarray],
) -> jnp.ndarray:
    """Combine every opacity source into a single c-k opacity cube.

    The ``"line"`` entry carries the g-point axis as its trailing dimension.
    The ``"rayleigh"``, ``"cia"``, ``"special"`` and ``"cloud"`` entries are
    indexed by (layer, wavelength) only and are broadcast onto every g-point.

    Returns an array shaped ``(nlay, nwl, ng)``. When no components are given,
    or ``"line"`` is absent, ``ng`` comes from the last axis of
    ``opac["g_weights"]``.
    """
    n_layers = state["rho_lay"].shape[0]
    n_wavelengths = state["wl"].shape[0]

    def _zero_cube() -> jnp.ndarray:
        # Only consult the quadrature weights when their size is actually needed.
        n_g = _get_ck_weights(opac).shape[-1]
        return jnp.zeros((n_layers, n_wavelengths, n_g))

    if not opacity_components:
        return _zero_cube()

    k_line = opacity_components.get("line")
    if k_line is None:
        k_line = _zero_cube()

    missing = jnp.zeros((n_layers, n_wavelengths), dtype=k_line.dtype)
    g_independent_terms = [
        opacity_components.get(name, missing)
        for name in ("rayleigh", "cia", "special", "cloud")
    ]
    k_g_independent = jnp.sum(jnp.stack(g_independent_terms, axis=0), axis=0)

    # Same opacity added to every g-point of a (layer, wavelength) bin.
    return k_line + k_g_independent[:, :, None]


def _planck_lambda(wavelength_cm: jnp.ndarray, temperature: jnp.ndarray) -> jnp.ndarray:
    """Evaluate the Planck function per unit wavelength, B_lambda(T).

    ``wavelength_cm`` and ``temperature`` are broadcast against each other.
    Units follow the ``h``, ``c_light`` and ``kb`` constants imported from
    ``data_constants`` (presumably cgs -- confirm there).

    Numerical guards:
    * temperatures below 1 are raised to 1 to avoid division by zero;
    * the exponent ``h c / (lambda k T)`` is capped at 80 so ``expm1`` cannot
      overflow. Beyond the cap, B stops decaying and levels off at a tiny
      positive floor instead of continuing down the Wien tail.
    """
    t_floor = jnp.maximum(temperature, 1.0)
    x = (h * c_light) / (wavelength_cm * kb * t_floor)
    x_capped = jnp.minimum(x, 80.0)
    denominator = jnp.maximum(jnp.expm1(x_capped), 1e-300)
    numerator = 2.0 * h * c_light**2 / (wavelength_cm**5)
    return numerator / denominator


def _layer_optical_depth_ck(k_tot: jnp.ndarray, rho: jnp.ndarray, dz: jnp.ndarray) -> jnp.ndarray:
    """Return the per-layer c-k optical depth ``k_tot * rho * dz``.

    ``rho`` and ``dz`` are indexed by layer along their first axis and are
    broadcast across the two trailing axes of ``k_tot`` (wavelength and
    g-point, as built by ``_sum_opacity_components_ck``).
    """
    trailing = (1, 2)
    return k_tot * jnp.expand_dims(rho, trailing) * jnp.expand_dims(dz, trailing)


def _compute_scattering_properties(
    opacity_components: Mapping[str, jnp.ndarray],
    state: Dict[str, jnp.ndarray],
    k_tot: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Return the single-scattering albedo and asymmetry parameter per g-point.

    Parameters
    ----------
    opacity_components
        Opacity sources. ``"rayleigh"``, ``"cloud"`` (cloud extinction),
        ``"cloud_ssa"`` and ``"cloud_g"`` are read, each indexed by
        (layer, wavelength); any that are missing default to zero.
    state
        Not used; kept so existing callers keep working.
    k_tot
        Total c-k opacity shaped ``(nlay, nwl, ng)``.

    Returns
    -------
    ssa, g
        Both shaped like ``k_tot``. ``ssa`` is the scattering opacity over
        ``k_tot``, clipped to ``[0, 0.99]``. ``g`` is the asymmetry parameter
        of the combined scatterers: Rayleigh scattering contributes g = 0, so
        ``g = k_cloud_scat * g_cloud / (k_rayleigh + k_cloud_scat)``.
    """
    nlay, nwl = k_tot.shape[:2]
    base_shape = (nlay, nwl)

    def _get_component(name: str) -> jnp.ndarray:
        arr = opacity_components.get(name)
        if arr is None:
            return jnp.zeros(base_shape, dtype=k_tot.dtype)
        return arr

    k_ray = _get_component("rayleigh")
    k_cloud_ext = _get_component("cloud")
    cloud_ssa = _get_component("cloud_ssa")
    cloud_g = _get_component("cloud_g")

    k_cloud_scat = cloud_ssa * k_cloud_ext
    k_scat = k_ray + k_cloud_scat

    # Scattering is grey across g-points, so broadcast along the last axis.
    # The lower bounds keep the divisions finite where nothing scatters/absorbs.
    ssa = jnp.clip(k_scat[:, :, None] / jnp.maximum(k_tot, 1.0e-30), 0.0, 0.99)

    # Scattering-weighted mean asymmetry parameter. Using cloud_g directly
    # would give Rayleigh-dominated layers the cloud's forward-scattering g
    # even where the cloud barely contributes. Where k_scat == 0 the numerator
    # is also 0, so g falls back to 0 (it is irrelevant there since ssa == 0).
    g_eff = k_cloud_scat * cloud_g / jnp.maximum(k_scat, 1.0e-30)
    g = jnp.broadcast_to(g_eff[:, :, None], k_tot.shape)
    return ssa, g


def _scale_flux_ratio(
    flux: jnp.ndarray,
    state: Dict[str, jnp.ndarray],
    params: Dict[str, jnp.ndarray],
) -> jnp.ndarray:
    """Scale ``flux`` by ``R0**2 / (F_star * R_s**2)``.

    ``F_star`` is ``state["stellar_flux"]`` when ``state["has_stellar_flux"]``
    is non-zero. Otherwise it is ``params["F_star"]``. Each falls back to
    ones when absent. All quantities are cast to ``flux.dtype``.
    """
    out_dtype = flux.dtype
    from_state = jnp.asarray(state.get("stellar_flux", jnp.ones_like(flux)), dtype=out_dtype)
    flag = jnp.asarray(state.get("has_stellar_flux", 0), dtype=jnp.int32)
    from_params = jnp.asarray(params.get("F_star", jnp.ones_like(from_state)), dtype=out_dtype)

    # Selected with jnp.where so the flag may be a traced value under jit.
    star_flux = jnp.where(flag != 0, from_state, from_params)

    r0 = jnp.asarray(state["R0"], dtype=out_dtype)
    r_s = jnp.asarray(state["R_s"], dtype=out_dtype)
    factor = (r0**2) / (star_flux * (r_s**2))
    return flux * factor


def _integrate_ck_emission(
    components: Mapping[str, jnp.ndarray],
    k_tot: jnp.ndarray,
    *,
    state: Dict[str, jnp.ndarray],
    opac: Dict[str, jnp.ndarray],
    rho_lay: jnp.ndarray,
    dz: jnp.ndarray,
    be_levels: jnp.ndarray,
    be_internal: jnp.ndarray,
    emission_solver,
    want_contrib: bool,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Run ``emission_solver`` at every g-point and sum with the c-k weights.

    Returns ``(lw_up, layer_contrib)``. ``lw_up`` has the shape of
    ``be_levels``. ``layer_contrib`` is shaped ``(nlay, nwl)`` and stays all
    zeros unless ``want_contrib`` is true.
    """
    ssa_ck, g_ck = _compute_scattering_properties(components, state, k_tot)
    dtau_ck = _layer_optical_depth_ck(k_tot, rho_lay, dz)

    # Move the g-point axis to the front so lax.scan iterates over g-points.
    dtau_by_g = jnp.moveaxis(dtau_ck, -1, 0)
    ssa_by_g = jnp.moveaxis(ssa_ck, -1, 0)
    asym_by_g = jnp.moveaxis(g_ck, -1, 0)
    n_g = dtau_by_g.shape[0]

    # The g-point axis of the weights is their LAST axis; the rest of this
    # module reads ng from g_weights.shape[-1]. For the usual 1-D weights this
    # equals slicing the first n_g entries. For weights that also carry leading
    # axes (NOTE(review): e.g. per-wavelength -- confirm if ever used), each scan
    # step receives a slice that broadcasts against the solver output.
    weights_by_g = jnp.moveaxis(jnp.asarray(_get_ck_weights(opac))[..., :n_g], -1, 0)

    lw_init = jnp.zeros_like(be_levels)
    contrib_init = jnp.zeros((dz.shape[0], be_levels.shape[-1]), dtype=be_levels.dtype)

    def _scan_body(carry, inputs):
        lw_acc, contrib_acc = carry
        dtau_g, ssa_g, asym_g, weight = inputs
        lw_g, _, contrib_g = emission_solver(
            be_levels, dtau_g, ssa_g, asym_g, be_internal,
            return_layer_contrib=want_contrib,
        )
        lw_acc = lw_acc + weight.astype(lw_acc.dtype) * lw_g
        if want_contrib:  # Python bool: resolved at trace time
            contrib_acc = contrib_acc + weight.astype(contrib_acc.dtype) * contrib_g
        return (lw_acc, contrib_acc), None

    (lw_up, layer_contrib), _ = lax.scan(
        _scan_body,
        (lw_init, contrib_init),
        (dtau_by_g, ssa_by_g, asym_by_g, weights_by_g),
    )
    return lw_up, layer_contrib


def compute_emission_spectrum_1d_ck(
    state: Dict[str, jnp.ndarray],
    params: Dict[str, jnp.ndarray],
    opacity_components: Mapping[str, jnp.ndarray],
    opac: Dict[str, jnp.ndarray],
    emission_solver=solve_alpha_eaa,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Compute a 1D correlated-k emission spectrum.

    Parameters
    ----------
    state
        Atmospheric state. The following keys are read:
        * ``wl``: wavelength grid, multiplied by 1e-4 to get cm, so presumably
          in microns;
        * ``T_lev``, ``rho_lay``, ``dz``, ``R0``;
        * optional flags ``contri_func`` and ``is_brown_dwarf``.
        The non-brown-dwarf scaling additionally reads ``R_s``,
        ``stellar_flux`` and ``has_stellar_flux``.
    params
        Optional keys:
        * ``T_int``: temperature of the bottom-boundary Planck source;
        * ``f_cloud``: cloudy fraction, clipped to [0, 1];
        * ``s_dilute``: multiplies the top flux.
        Brown dwarfs also use ``D`` (multiplied by ``pc``) and
        ``log_10_f_p``. Otherwise ``F_star`` is used for the flux-ratio
        scaling.
    opacity_components
        Opacity sources by name: ``line`` with a trailing g-point axis, and
        ``rayleigh``, ``cia``, ``special``, ``cloud``, ``cloud_ssa``,
        ``cloud_g``.
    opac
        Opacity data; ``opac["g_weights"]`` holds the c-k quadrature weights.
    emission_solver
        Called as ``(be_levels, dtau, ssa, g, be_internal,
        return_layer_contrib=...)`` and must return a 3-tuple. Item 0 is the
        level flux whose first row is used as the top-of-atmosphere flux.
        Item 2 is the per-layer contribution.

    Returns
    -------
    final_spectrum, contrib_func_norm
        The scaled spectrum ``(nwl,)``, and the contribution function
        ``(nlay, nwl)`` normalised over layers at each wavelength. The
        contribution function is all zeros unless ``contri_func`` is set.
    """
    contri_func = state.get("contri_func", False)
    dtype = state["wl"].dtype
    wl_cm = jnp.asarray(state["wl"], dtype=dtype) * 1.0e-4
    T_lev = jnp.asarray(state["T_lev"], dtype=dtype)
    rho_lay = jnp.asarray(state["rho_lay"], dtype=dtype)
    dz = jnp.asarray(state["dz"], dtype=dtype)

    be_levels = _planck_lambda(wl_cm[None, :], T_lev[:, None])
    if "T_int" in params:
        T_int = jnp.asarray(params["T_int"], dtype=dtype)
        be_internal = _planck_lambda(wl_cm[None, :], T_int[None, None])[0]
    else:
        be_internal = jnp.zeros_like(be_levels[-1])

    def _lw_up_for_components(
        components: Mapping[str, jnp.ndarray], k_tot_local: jnp.ndarray
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        return _integrate_ck_emission(
            components,
            k_tot_local,
            state=state,
            opac=opac,
            rho_lay=rho_lay,
            dz=dz,
            be_levels=be_levels,
            be_internal=be_internal,
            emission_solver=emission_solver,
            want_contrib=bool(contri_func),
        )

    k_tot_cloud = _sum_opacity_components_ck(state, opacity_components, opac)
    lw_up_cloud, layer_contrib_cloud = _lw_up_for_components(opacity_components, k_tot_cloud)

    if "f_cloud" in params and "cloud" in opacity_components:
        f_cloud = jnp.clip(params["f_cloud"], 0.0, 1.0)
        zeros = jnp.zeros_like(opacity_components["cloud"])
        opacity_clear = dict(opacity_components)
        opacity_clear["cloud"] = zeros
        opacity_clear["cloud_ssa"] = zeros
        opacity_clear["cloud_g"] = zeros
        # Rebuild the clear-sky total from its components rather than computing
        # k_tot_cloud - cloud: the subtraction loses the gas opacity (and can
        # even go negative) to cancellation when the cloud term dominates.
        k_tot_clear = _sum_opacity_components_ck(state, opacity_clear, opac)
        lw_up_clear, layer_contrib_clear = _lw_up_for_components(opacity_clear, k_tot_clear)
        lw_up = f_cloud * lw_up_cloud + (1.0 - f_cloud) * lw_up_clear
        layer_contrib_flux = f_cloud * layer_contrib_cloud + (1.0 - f_cloud) * layer_contrib_clear
    else:
        lw_up = lw_up_cloud
        layer_contrib_flux = layer_contrib_cloud

    s_dilute = jnp.asarray(params.get("s_dilute", 1.0))
    top_flux = s_dilute * lw_up[0]

    if state.get("is_brown_dwarf", False):
        R0 = jnp.asarray(state["R0"], dtype=top_flux.dtype)
        D = params["D"]
        f_p = 10.0 ** jnp.asarray(params.get("log_10_f_p", 0.0), dtype=top_flux.dtype)
        distance = D * pc
        final_spectrum = top_flux * f_p * (R0 / distance) ** 2
    else:
        final_spectrum = _scale_flux_ratio(top_flux, state, params)

    if contri_func:
        layer_contrib = jnp.clip(layer_contrib_flux, 0.0)
        # Normalise over layers at each wavelength; the floor avoids 0/0.
        contrib_func_norm = layer_contrib / jnp.maximum(
            layer_contrib.sum(axis=0, keepdims=True), 1e-30
        )
    else:
        contrib_func_norm = jnp.zeros(
            (state["dz"].shape[0], final_spectrum.shape[0]), dtype=final_spectrum.dtype
        )
    return final_spectrum, contrib_func_norm