Source code for exo_skryer.RT_alb_1D

"""
RT_alb_1D.py
============
"""

from __future__ import annotations

from typing import Dict, Mapping, Tuple

import jax.numpy as jnp

from . import build_opacities as XS

__all__ = ["compute_albedo_spectrum_1d"]


def _get_ck_weights(state: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    g_weights = state.get("g_weights")
    if g_weights is not None:
        return g_weights
    if not XS.has_ck_data():
        raise RuntimeError("c-k g-weights not built; run build_opacities() with ck tables.")
    g_weights = XS.ck_g_weights()
    if g_weights.ndim > 1:
        g_weights = g_weights[0]
    return g_weights


def _sum_opacity_components_ck(
    state: Dict[str, jnp.ndarray],
    opacity_components: Mapping[str, jnp.ndarray],
) -> jnp.ndarray:
    """
    Return the summed opacity grid for correlated-k mode.

    Line opacity has shape (nlay, nwl, ng).
    Other opacities have shape (nlay, nwl) and are broadcast over g-dimension.
    Returns shape (nlay, nwl, ng).
    """
    nlay = state["nlay"]
    nwl = state["nwl"]

    line_opacity = opacity_components.get("line")
    if line_opacity is None:
        ng = _get_ck_weights(state).shape[-1]
        line_opacity = jnp.zeros((nlay, nwl, ng))

    zeros_2d = jnp.zeros((nlay, nwl), dtype=line_opacity.dtype)
    component_keys_2d = ("rayleigh", "cia", "special", "cloud")
    components_2d = jnp.stack([opacity_components.get(k, zeros_2d) for k in component_keys_2d], axis=0)
    summed_2d = jnp.sum(components_2d, axis=0)

    return line_opacity + summed_2d[:, :, None]


def _sum_opacity_components_os(
    state: Dict[str, jnp.ndarray],
    opacity_components: Mapping[str, jnp.ndarray],
) -> jnp.ndarray:
    """Return the summed opacity grid for all provided components (OS mode)."""
    nlay = state["nlay"]
    nwl = state["nwl"]

    if not opacity_components:
        return jnp.zeros((nlay, nwl))

    component_keys = ("line", "rayleigh", "cia", "special", "cloud")
    first = next((opacity_components.get(k) for k in component_keys if k in opacity_components), None)
    if first is None:
        return jnp.zeros((nlay, nwl))

    zeros = jnp.zeros_like(first)
    stacked = jnp.stack([opacity_components.get(k, zeros) for k in component_keys], axis=0)
    return jnp.sum(stacked, axis=0)


[docs] def compute_albedo_spectrum_1d( state: Dict[str, jnp.ndarray], params: Dict[str, jnp.ndarray], opacity_components: Mapping[str, jnp.ndarray], ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Placeholder shortwave albedo RT kernel. Parameters ---------- state Standard forward-model state dict (see `build_model.forward_model`). Expected keys include `wl`, `nwl`, `nlay`, `ck`, and optionally `contri_func`. params Retrieval parameters (may include cloud patchiness, phase angle, etc.). opacity_components Opacity component mapping (line/rayleigh/cia/special/cloud and cloud scattering props). Returns ------- (alb_spectrum, contrib_func) `alb_spectrum`: (nwl,) reflected-light metric (TBD: geometric albedo vs reflectance). `contrib_func`: (nlay, nwl) normalized contribution function if enabled, else zeros. """ _ = (state, params, opacity_components) raise NotImplementedError( "compute_albedo_spectrum_1d is a placeholder. " "Implement a shortwave scattering/reflection solver and define the albedo convention." )