# Source code for exo_skryer.instru_convolve

"""
instru_convolve.py
==================
"""

from __future__ import annotations

import jax.numpy as jnp

from .registry_bandpass import (
    bandpass_num_bins,
    bandpass_indices_padded,
    bandpass_coefficients_padded,
)

# Public API: the one-shot convolution entry point, the bandpass cache
# builder, and the variant that convolves with a caller-supplied cache.
__all__ = [
    "apply_response_functions",
    "get_bandpass_cache",
    "apply_response_functions_cached",
]


def _convolve_spectrum_core(
    spec: jnp.ndarray,
    idx_pad: jnp.ndarray,
    coeff_pad: jnp.ndarray,
) -> jnp.ndarray:
    """Bin a high-resolution spectrum using precomputed linear weights.

    Every observational bin is a fixed linear combination of samples of
    `spec`. The bandpass registry is responsible for baking the trapezoidal
    quadrature (response weights, optional wavelength weighting,
    normalization, and the single-point-bin fallback) into `coeff_pad`, so
    all that remains here is one gather followed by one weighted row sum.

    Parameters
    ----------
    spec : `~jax.numpy.ndarray`, shape (nwl_hi,)
        High-resolution spectrum on the master wavelength grid.
    idx_pad : `~jax.numpy.ndarray`, shape (nbin, max_len)
        For each bin, the (padded) positions in `spec` of the samples that
        contribute to it.
    coeff_pad : `~jax.numpy.ndarray`, shape (nbin, max_len)
        For each bin, the (padded) linear coefficient applied to the sample
        at the matching entry of `idx_pad`.

    Returns
    -------
    binned_spectrum : `~jax.numpy.ndarray`, shape (nbin,)
        Spectrum convolved onto the observational bins.
    """
    # Pull out the samples each bin needs; result has shape (nbin, max_len).
    gathered = jnp.take(spec, idx_pad, axis=0)
    # Weight sample-by-sample, then collapse the per-bin sample axis.
    weighted = gathered * coeff_pad
    return jnp.sum(weighted, axis=1)


def get_bandpass_cache() -> dict[str, jnp.ndarray]:
    """Materialize bandpass registry arrays into a single PyTree cache (outside jit).

    Returns
    -------
    cache : dict of str to `~jax.numpy.ndarray`
        Dictionary with exactly two entries:

        ``"idx_pad"``
            The array returned by `bandpass_indices_padded`, or an empty
            ``(0, 0)`` int32 array when the registry reports zero bins.
        ``"coeff_pad"``
            The array returned by `bandpass_coefficients_padded`, or an
            empty ``(0, 0)`` array requested as float64 when the registry
            reports zero bins.
    """
    n_bins = bandpass_num_bins()
    if n_bins == 0:
        # No bins registered: return empty placeholders under the same keys
        # so consumers can index the cache without special-casing.
        # NOTE(review): the float64 request presumably relies on JAX x64
        # mode being enabled -- confirm against the project's JAX config.
        empty_f = jnp.zeros((0, 0), dtype=jnp.float64)
        empty_i = jnp.zeros((0, 0), dtype=jnp.int32)
        return {
            "idx_pad": empty_i,
            "coeff_pad": empty_f,
        }
    return {
        "idx_pad": bandpass_indices_padded(),
        "coeff_pad": bandpass_coefficients_padded(),
    }


def apply_response_functions_cached(
    spectrum: jnp.ndarray,
    cache: dict[str, jnp.ndarray],
) -> jnp.ndarray:
    """Convolve spectrum using a provided bandpass cache (jit-friendly).

    Parameters
    ----------
    spectrum : `~jax.numpy.ndarray`
        High-resolution model spectrum to be binned.
    cache : dict of str to `~jax.numpy.ndarray`
        Mapping that provides the ``"idx_pad"`` and ``"coeff_pad"`` arrays,
        e.g. as built by `get_bandpass_cache`.

    Returns
    -------
    binned_spectrum : `~jax.numpy.ndarray`
        Result of `_convolve_spectrum_core` for the cached indices and
        coefficients. When ``cache["coeff_pad"]`` has no elements, an empty
        array of shape ``(0,)`` with the dtype of `spectrum` is returned
        instead.
    """
    coeff_pad = cache["coeff_pad"]
    # An empty coefficient table means there is nothing to bin onto; return
    # early with an empty result that keeps the input dtype.
    if coeff_pad.size == 0:
        return jnp.zeros((0,), dtype=spectrum.dtype)
    return _convolve_spectrum_core(
        spec=spectrum,
        idx_pad=cache["idx_pad"],
        coeff_pad=coeff_pad,
    )


def apply_response_functions(spectrum: jnp.ndarray) -> jnp.ndarray:
    """Apply instrument response functions to convolve spectrum onto observational bins.

    This function takes a high-resolution model spectrum and convolves it
    with pre-loaded instrument response functions to produce a binned
    spectrum matching the observational wavelength grid. The response
    functions (boxcar, Gaussian, filter throughput curves, etc.) are
    retrieved from the bandpass registry.

    For boxcar bins, the integration is simple averaging:

        F_bin[i] = ∫ F(λ) dλ / ∫ dλ

    For filter curve bins (non-boxcar), the integration is photon-weighted:

        F_bin[i] = ∫ F(λ) T(λ) λ dλ / ∫ T(λ) λ dλ

    Parameters
    ----------
    spectrum : `~jax.numpy.ndarray`, shape (nwl_hi,)
        High-resolution model spectrum evaluated on the master wavelength
        grid. This should be the output from a radiative transfer
        calculation.

    Returns
    -------
    binned_spectrum : `~jax.numpy.ndarray`, shape (nbin,)
        Convolved spectrum in observational bins. If no bins are registered
        (nbin=0), returns an empty array with the same dtype as `spectrum`.

    Notes
    -----
    The bandpass cache is rebuilt through `get_bandpass_cache` on every
    call. Callers that already hold a cache can pass it to
    `apply_response_functions_cached` directly.
    """
    return apply_response_functions_cached(spectrum, get_bandpass_cache())