# Source code for exo_skryer.registry_bandpass

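"""
registry_bandpass.py
====================

Registry of per-bin bandpass convolution data (boxcar or filter-throughput
responses), prepared as padded NumPy/JAX arrays for runtime convolution.
"""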

from __future__ import annotations

from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Tuple, List

import numpy as np
import jax.numpy as jnp
from scipy.integrate import simpson


__all__ = [
    "BinConvolutionEntry",
    "reset_bandpass_registry",
    "has_bandpass_data",
    "load_bandpass_registry",
    "bandpass_num_bins",
    "bandpass_bin_edges",
    "bandpass_wavelengths_padded",
    "bandpass_weights_padded",
    "bandpass_indices_padded",
    "bandpass_coefficients_padded",
    "bandpass_norms",
    "bandpass_valid_lengths",
    "bandpass_is_boxcar",
]

# --- Dataclass and global registries ---


@dataclass(frozen=True)
class BinConvolutionEntry:
    """
    Immutable description of how a single observational bin is convolved.

    While the registry is being assembled, the array fields are plain NumPy
    float64 arrays held on the host; conversion to JAX only happens once the
    padded caches are built. Float64 is kept throughout for accuracy.

    Attributes
    ----------
    method : str
        ``"boxcar"`` or the name of a filter mode whose throughput is used.
    wavelengths : np.ndarray
        Slice of the cut high-resolution wavelength grid covered by the bin.
    weights : np.ndarray
        Response weights evaluated on ``wavelengths``.
    norm : float
        Precomputed normalisation constant for the bin.
    indices : tuple of int
        ``(start, end)`` index range into the cut wavelength grid.
    bin_edges : tuple of float
        Requested ``(low, high)`` wavelength edges of the bin.
    """

    # Response type: "boxcar" or a filter-mode identifier.
    method: str
    # High-resolution wavelengths inside the bin (float64, NumPy).
    wavelengths: np.ndarray
    # Weights matching ``wavelengths`` element for element (float64, NumPy).
    weights: np.ndarray
    # Normalisation constant computed during preprocessing.
    norm: float
    # (start, end) indices into the CUT wavelength grid.
    indices: Tuple[int, int]
    # Intended (left, right) bin edges.
    bin_edges: Tuple[float, float]
# Global entries and cached arrays.
# ``_BAND_ENTRIES`` and every ``_BAND_*_CACHE`` are filled by
# load_bandpass_registry() and emptied by reset_bandpass_registry().
_BAND_ENTRIES: Tuple[BinConvolutionEntry, ...] = ()
# Host-side NumPy arrays; their accessors convert them to JAX on first use.
_BAND_WL_PAD_CACHE: np.ndarray | None = None  # Padded wavelengths, (n_bins, max_len)
_BAND_W_PAD_CACHE: np.ndarray | None = None  # Padded weights, (n_bins, max_len)
# Device-side JAX arrays used directly on the runtime hot path.
_BAND_IDX_PAD_CACHE: jnp.ndarray | None = None  # Padded int32 indices into the cut grid
_BAND_COEFF_PAD_CACHE: jnp.ndarray | None = None  # Padded convolution coefficients
# Further host-side NumPy arrays, one value per bin.
_BAND_NORM_CACHE: np.ndarray | None = None  # Normalisation constant for each bin
_BAND_VALID_LENS_CACHE: np.ndarray | None = None  # Valid (non-padded) length for each bin
_BAND_BOXCAR_CACHE: np.ndarray | None = None  # Boxcar detection flags for each bin

# Map instrument modes to filter filenames (looked up under telescope_data/).
_MODE_TO_FILE = {
    "S36": "Spitzer_irac1_bandpass.dat",
    "S45": "Spitzer_irac2_bandpass.dat",
    "MIRI_4A": "JWST_MIRI_F1800W.dat",
}


# --- Internal helpers ---


def _clear_cache():
    """Invalidate every lru_cache-memoised accessor so it re-reads the globals."""
    for accessor in (
        bandpass_num_bins,
        bandpass_bin_edges,
        bandpass_wavelengths_padded,
        bandpass_weights_padded,
        bandpass_indices_padded,
        bandpass_coefficients_padded,
        bandpass_norms,
        bandpass_valid_lengths,
        bandpass_is_boxcar,
    ):
        accessor.cache_clear()
def reset_bandpass_registry():
    """
    Empty the bandpass registry and discard every cached array.

    All per-bin entries are dropped, each ``_BAND_*_CACHE`` global goes back
    to ``None``, and the memoised accessors are invalidated so the next call
    re-reads the (now empty) globals.
    """
    global _BAND_ENTRIES
    global _BAND_WL_PAD_CACHE, _BAND_W_PAD_CACHE
    global _BAND_IDX_PAD_CACHE, _BAND_COEFF_PAD_CACHE
    global _BAND_NORM_CACHE, _BAND_VALID_LENS_CACHE, _BAND_BOXCAR_CACHE

    _BAND_ENTRIES = ()
    (
        _BAND_WL_PAD_CACHE,
        _BAND_W_PAD_CACHE,
        _BAND_IDX_PAD_CACHE,
        _BAND_COEFF_PAD_CACHE,
        _BAND_NORM_CACHE,
        _BAND_VALID_LENS_CACHE,
        _BAND_BOXCAR_CACHE,
    ) = (None,) * 7
    _clear_cache()
def has_bandpass_data() -> bool:
    """
    Report whether the registry currently holds any bin entries.

    Returns
    -------
    bool
        True when at least one bin entry is stored, False otherwise.
    """
    return len(_BAND_ENTRIES) != 0
@lru_cache(maxsize=None)
def _read_transmission_file(path_str: str, mode: str) -> Tuple[np.ndarray, np.ndarray]:
    """Parse a two-column transmission file once and memoise the result."""
    path = Path(path_str)
    if not path.exists():
        raise FileNotFoundError(f"Transmission file not found: {path}")

    print(f"[info] Loading filter data for '{mode}' from {path}")
    rows = []
    with path.open("r", encoding="utf-8") as handle:
        for line in handle:
            stripped = line.strip()
            if not stripped or stripped.startswith("#"):
                continue
            parts = stripped.split()
            if len(parts) < 2:
                continue
            try:
                rows.append((float(parts[0]), float(parts[1])))
            except ValueError:
                # Skip non-numeric header lines like "Photon counter".
                continue

    if not rows:
        raise ValueError(f"No valid transmission data found in {path}")

    data = np.asarray(rows, dtype=float)
    # np.interp assumes increasing sample points and does not check; a stable
    # sort guarantees that while keeping file order for repeated wavelengths.
    data = data[np.argsort(data[:, 0], kind="stable")]
    return data[:, 0], data[:, 1]


def get_filter_data(mode: str) -> Tuple[np.ndarray, np.ndarray]:
    """
    Return the raw transmission curve (wavelength, throughput) for a filter mode.

    Each file is parsed once per process and memoised. Every call returns
    fresh copies, so callers may modify the arrays without corrupting the
    cache. Samples are returned sorted by increasing wavelength, as required
    by ``np.interp``.

    Parameters
    ----------
    mode : str
        Key into ``_MODE_TO_FILE`` (e.g. ``"S36"``).

    Returns
    -------
    wl_filter : np.ndarray
        Wavelengths of the filter transmission curve, ascending.
    throughput : np.ndarray
        Corresponding throughput values (dimensionless).

    Raises
    ------
    FileNotFoundError
        If no file is mapped for ``mode``, or the mapped file does not exist.
    ValueError
        If the file contains no numeric (wavelength, throughput) rows.
    """
    if mode not in _MODE_TO_FILE:
        raise FileNotFoundError(f"No transmission file is mapped for filter mode '{mode}'.")

    base_dir = Path(__file__).resolve().parent.parent / "telescope_data"
    path = base_dir / _MODE_TO_FILE[mode]
    wl_filter, throughput = _read_transmission_file(str(path), mode)
    # Hand out copies so the memoised arrays can never be mutated by callers.
    return wl_filter.copy(), throughput.copy()


# --- Main preparation function (NumPy → JAX) ---
def _per_bin_values(values: np.ndarray, nobs: int, name: str) -> np.ndarray:
    """Return ``values`` with at least one entry per bin, broadcasting a scalar."""
    if values.ndim == 0:
        return np.full(nobs, values.item(), dtype=values.dtype)
    if values.shape[0] < nobs:
        raise ValueError(
            f"obs['{name}'] has {values.shape[0]} entries but obs['wl'] has {nobs}."
        )
    return values


def _build_bin_entry(
    idx: int,
    mode: str,
    center: float,
    half_width: float,
    wl_hi: np.ndarray,
    filter_tables: dict,
) -> BinConvolutionEntry:
    """Build the convolution entry for one bin on its slice of the cut grid ``wl_hi``."""
    final_method = "boxcar" if (not mode or mode.lower() == "boxcar") else mode
    is_boxcar = final_method.lower() == "boxcar"

    # Load the filter curve once per mode; fall back to boxcar when it is missing.
    filter_wl = filter_throughput = None
    if not is_boxcar:
        if final_method not in filter_tables:
            try:
                filter_tables[final_method] = get_filter_data(final_method)
            except FileNotFoundError as exc:
                filter_tables[final_method] = exc
        table = filter_tables[final_method]
        if isinstance(table, FileNotFoundError):
            print(f"[warn] {table}. Skipping bin {idx} and treating as boxcar.")
            final_method = "boxcar"
            is_boxcar = True
        else:
            filter_wl, filter_throughput = table

    # Bin edges and the half-open index range they cover on the cut grid.
    low, high = center - half_width, center + half_width
    start_idx = int(np.searchsorted(wl_hi, low, side="left"))
    end_idx = int(np.searchsorted(wl_hi, high, side="right"))
    if start_idx >= end_idx:
        # Bin is narrower than the grid spacing or lies outside the grid:
        # use the nearest grid point. wl_hi is non-empty, so the slice below
        # always holds at least one sample.
        start_idx = int(np.abs(wl_hi - center).argmin())
        end_idx = start_idx + 1
    wl_slice = wl_hi[start_idx:end_idx]

    if is_boxcar:
        weights_slice = np.ones_like(wl_slice)
    else:
        weights_slice = np.interp(
            wl_slice, filter_wl, filter_throughput, left=0.0, right=0.0
        )

    # Norm depends on the response type:
    # - Boxcar:     ∫ dλ            (plain width)
    # - Non-boxcar: ∫ T(λ) λ dλ     (photon-weighted for photon-counting detectors)
    if wl_slice.size > 1:
        integrand = weights_slice if is_boxcar else weights_slice * wl_slice
        norm = float(simpson(integrand, x=wl_slice))
    else:
        norm = 1.0

    # NaN/inf norms (e.g. from NaN throughput) would silently poison the output.
    if not (np.isfinite(norm) and norm > 0.0):
        print(
            f"[warn] Non-positive or non-finite norm ({norm}) for bin {idx} with method "
            f"{final_method!r}. Falling back to norm=1.0."
        )
        norm = 1.0

    return BinConvolutionEntry(
        method=final_method,
        wavelengths=wl_slice.astype(np.float64),
        weights=weights_slice.astype(np.float64),
        norm=np.asarray(norm, dtype=np.float64),
        indices=(start_idx, end_idx),
        bin_edges=(float(low), float(high)),
    )


def _bin_coefficients(entry: BinConvolutionEntry) -> np.ndarray:
    """Trapezoid-rule coefficients for one bin (incl. weights, λ factor and 1/norm)."""
    wl = np.asarray(entry.wavelengths, dtype=float)
    if wl.size == 1:
        # Match the runtime fallback for single-point bins exactly.
        return np.ones(1, dtype=float)
    trap_weights = np.empty(wl.size, dtype=float)
    trap_weights[0] = 0.5 * (wl[1] - wl[0])
    trap_weights[-1] = 0.5 * (wl[-1] - wl[-2])
    if wl.size > 2:
        trap_weights[1:-1] = 0.5 * (wl[2:] - wl[:-2])
    lambda_weight = 1.0 if entry.method.lower() == "boxcar" else wl
    weights = np.asarray(entry.weights, dtype=float)
    return trap_weights * weights * lambda_weight / max(float(entry.norm), 1e-99)


def load_bandpass_registry(
    obs: dict,
    full_grid: np.ndarray,
    cut_grid: np.ndarray,
) -> None:
    """
    Build the bandpass registry and JAX-ready padded arrays for each observational bin.

    Parameters
    ----------
    obs : dict
        Observation info, must contain:
        - 'wl'  : observed central wavelengths (1D array)
        - 'dwl' : half-widths of each bin (1D array, or a scalar for all bins)
        - 'response_mode' : identifiers for each bin (e.g. "boxcar", "S36",
          "S45"), or a single identifier for all bins. An empty array means
          every bin is a boxcar.
    full_grid : `~numpy.ndarray`
        Full high-resolution wavelength grid (currently unused but kept for
        API stability).
    cut_grid : `~numpy.ndarray`
        Cut high-resolution wavelength grid on which convolution will be
        performed. Assumed sorted ascending, as required by ``np.searchsorted``.

    Raises
    ------
    ValueError
        If ``cut_grid`` is empty while bins are requested, or if ``obs['dwl']``
        or ``obs['response_mode']`` has fewer entries than ``obs['wl']``.
    """
    global _BAND_ENTRIES, _BAND_WL_PAD_CACHE, _BAND_W_PAD_CACHE, _BAND_IDX_PAD_CACHE, _BAND_COEFF_PAD_CACHE, _BAND_NORM_CACHE, _BAND_VALID_LENS_CACHE, _BAND_BOXCAR_CACHE

    wl_hi = np.asarray(cut_grid, dtype=float)  # high-res grid used for convolution
    wl_obs = np.atleast_1d(np.asarray(obs["wl"], dtype=float))
    nobs = wl_obs.shape[0]
    if nobs and wl_hi.size == 0:
        raise ValueError("cut_grid is empty; cannot build the bandpass registry.")

    dwl_obs = _per_bin_values(np.asarray(obs["dwl"], dtype=float), nobs, "dwl")
    response_modes = np.asarray(obs["response_mode"])
    if response_modes.size == 0:
        modes = ["boxcar"] * nobs
    else:
        modes = [
            str(m).strip()
            for m in _per_bin_values(response_modes, nobs, "response_mode")[:nobs]
        ]

    # --- First pass: build per-bin entries on irregular slices (NumPy) ---
    filter_tables: dict = {}  # mode -> (wl, throughput) or the FileNotFoundError raised
    _BAND_ENTRIES = tuple(
        _build_bin_entry(idx, mode, center, half_width, wl_hi, filter_tables)
        for idx, (mode, center, half_width) in enumerate(zip(modes, wl_obs, dwl_obs))
    )

    if not _BAND_ENTRIES:
        # Nothing to do, clear everything
        reset_bandpass_registry()
        return

    # --- Second pass: build padded rectangular arrays (NumPy) ---
    n_bins = len(_BAND_ENTRIES)
    max_len = max(int(e.wavelengths.shape[0]) for e in _BAND_ENTRIES)

    padded_wl = np.zeros((n_bins, max_len), dtype=float)
    padded_w = np.zeros((n_bins, max_len), dtype=float)  # zero tail: padding has no weight
    padded_idx = np.zeros((n_bins, max_len), dtype=int)
    padded_coeff = np.zeros((n_bins, max_len), dtype=float)  # zero tail: no contribution
    norms_np = np.zeros((n_bins,), dtype=float)
    valid_lens_np = np.zeros((n_bins,), dtype=int)  # valid (non-padded) length per bin
    is_boxcar_np = np.zeros((n_bins,), dtype=bool)  # boxcar detection flags

    for i, e in enumerate(_BAND_ENTRIES):
        wl = np.asarray(e.wavelengths, dtype=float)
        length = wl.size
        idxs = np.arange(*e.indices, dtype=int)

        padded_wl[i, :length] = wl
        padded_w[i, :length] = e.weights
        padded_idx[i, :length] = idxs
        padded_coeff[i, :length] = _bin_coefficients(e)

        if length < max_len:
            # Pad tail: repeat the last wavelength and index; weights and
            # coefficients are already zero there.
            padded_wl[i, length:] = wl[-1]
            padded_idx[i, length:] = idxs[-1]

        norms_np[i] = float(e.norm)
        valid_lens_np[i] = length
        is_boxcar_np[i] = e.method.lower() == "boxcar"

    # Keep legacy diagnostic arrays as NumPy until their accessors are called.
    # Only the runtime hot-path arrays are transferred eagerly.
    print(f"[Bandpass] Preparing {n_bins} bins...")
    _BAND_WL_PAD_CACHE = padded_wl
    _BAND_W_PAD_CACHE = padded_w
    _BAND_IDX_PAD_CACHE = jnp.asarray(padded_idx, dtype=jnp.int32)
    _BAND_COEFF_PAD_CACHE = jnp.asarray(padded_coeff, dtype=jnp.float64)
    _BAND_NORM_CACHE = norms_np
    _BAND_VALID_LENS_CACHE = valid_lens_np
    _BAND_BOXCAR_CACHE = is_boxcar_np

    print(f"[Bandpass] Wavelength cache: {_BAND_WL_PAD_CACHE.shape} (host dtype: {_BAND_WL_PAD_CACHE.dtype})")
    print(f"[Bandpass] Weights cache: {_BAND_W_PAD_CACHE.shape} (host dtype: {_BAND_W_PAD_CACHE.dtype})")
    print(f"[Bandpass] Index cache: {_BAND_IDX_PAD_CACHE.shape} (dtype: {_BAND_IDX_PAD_CACHE.dtype})")
    print(f"[Bandpass] Coeff cache: {_BAND_COEFF_PAD_CACHE.shape} (dtype: {_BAND_COEFF_PAD_CACHE.dtype})")
    print(f"[Bandpass] Norm cache: {_BAND_NORM_CACHE.shape} (dtype: {_BAND_NORM_CACHE.dtype})")
    print(f"[Bandpass] Valid lengths cache: {_BAND_VALID_LENS_CACHE.shape} (dtype: {_BAND_VALID_LENS_CACHE.dtype})")

    # Estimate memory usage
    mib = 1024**2
    wl_mb = padded_wl.nbytes / mib
    w_mb = padded_w.nbytes / mib
    idx_mb = _BAND_IDX_PAD_CACHE.size * _BAND_IDX_PAD_CACHE.itemsize / mib
    coeff_mb = _BAND_COEFF_PAD_CACHE.size * _BAND_COEFF_PAD_CACHE.itemsize / mib
    total_device_mb = idx_mb + coeff_mb
    total_host_mb = wl_mb + w_mb + norms_np.nbytes / mib + valid_lens_np.nbytes / mib
    print(f"[Bandpass] Estimated device memory: {total_device_mb:.3f} MB (idx: {idx_mb:.3f}, coeff: {coeff_mb:.3f}); host legacy cache: {total_host_mb:.3f} MB")

    _clear_cache()
# --- lru_cache helper accessors (JAX-ready) ---
@lru_cache(maxsize=None)
def bandpass_num_bins() -> int:
    """
    Count the observational bins currently held by the registry.

    Returns 0 (rather than raising) when nothing has been loaded.
    """
    entries = _BAND_ENTRIES
    return len(entries)
@lru_cache(maxsize=None)
def bandpass_bin_edges() -> jnp.ndarray:
    """
    Requested edges of every bin as a JAX array.

    Returns
    -------
    jnp.ndarray, shape (n_bins, 2)
        One ``[λ_low, λ_high]`` row per bin.

    Raises
    ------
    RuntimeError
        If the registry holds no entries.
    """
    if len(_BAND_ENTRIES) == 0:
        raise RuntimeError("Bandpass registry empty; call load_bandpass_registry() first.")
    edge_rows = [entry.bin_edges for entry in _BAND_ENTRIES]
    return jnp.asarray(np.array(edge_rows, dtype=float))
@lru_cache(maxsize=None)
def bandpass_wavelengths_padded() -> jnp.ndarray:
    """
    Per-bin wavelength slices padded to a common length.

    Returns
    -------
    jnp.ndarray, shape (n_bins, max_len)
        The stored padded wavelengths, converted to a float64 JAX array on
        first call and memoised afterwards.

    Raises
    ------
    RuntimeError
        If the padded arrays have not been built yet.
    """
    host_array = _BAND_WL_PAD_CACHE
    if host_array is not None:
        return jnp.asarray(host_array, dtype=jnp.float64)
    raise RuntimeError("Bandpass padded arrays not built; call load_bandpass_registry() first.")
@lru_cache(maxsize=None)
def bandpass_weights_padded() -> jnp.ndarray:
    """
    Per-bin response weights padded to a common length.

    Returns
    -------
    jnp.ndarray, shape (n_bins, max_len)
        The stored padded weights, converted to a float64 JAX array on first
        call and memoised afterwards.

    Raises
    ------
    RuntimeError
        If the padded arrays have not been built yet.
    """
    host_array = _BAND_W_PAD_CACHE
    if host_array is not None:
        return jnp.asarray(host_array, dtype=jnp.float64)
    raise RuntimeError("Bandpass padded arrays not built; call load_bandpass_registry() first.")
@lru_cache(maxsize=None)
def bandpass_indices_padded() -> jnp.ndarray:
    """
    Padded indices into the high-resolution spectrum grid.

    Returns
    -------
    jnp.ndarray, shape (n_bins, max_len)
        The stored index array, returned as-is (no conversion).

    Raises
    ------
    RuntimeError
        If the padded arrays have not been built yet.
    """
    cached = _BAND_IDX_PAD_CACHE
    if cached is not None:
        return cached
    raise RuntimeError("Bandpass padded arrays not built; call load_bandpass_registry() first.")
@lru_cache(maxsize=None)
def bandpass_coefficients_padded() -> jnp.ndarray:
    """
    Padded linear convolution coefficients, shape (n_bins, max_len).

    Multiplying these coefficients by the gathered high-resolution spectrum
    and summing over ``axis=1`` reproduces the trapezoidal integration the
    original convolution core performed at runtime.

    Raises
    ------
    RuntimeError
        If the coefficient array has not been built yet.
    """
    cached = _BAND_COEFF_PAD_CACHE
    if cached is not None:
        return cached
    raise RuntimeError("Bandpass coefficient arrays not built; call load_bandpass_registry() first.")
@lru_cache(maxsize=None)
def bandpass_norms() -> jnp.ndarray:
    """
    Normalisation constant of every bin.

    Returns
    -------
    jnp.ndarray, shape (n_bins,)
        The stored norms as a float64 JAX array, memoised after the first call.

    Raises
    ------
    RuntimeError
        If the norms have not been built yet.
    """
    host_array = _BAND_NORM_CACHE
    if host_array is not None:
        return jnp.asarray(host_array, dtype=jnp.float64)
    raise RuntimeError("Bandpass norms not built; call load_bandpass_registry() first.")
@lru_cache(maxsize=None)
def bandpass_valid_lengths() -> jnp.ndarray:
    """
    Number of real (non-padding) samples in each bin.

    Returns
    -------
    jnp.ndarray, shape (n_bins,)
        The stored lengths as an int32 JAX array, memoised after the first call.

    Raises
    ------
    RuntimeError
        If the valid lengths have not been built yet.
    """
    host_array = _BAND_VALID_LENS_CACHE
    if host_array is not None:
        return jnp.asarray(host_array, dtype=jnp.int32)
    raise RuntimeError("Bandpass valid lengths not built; call load_bandpass_registry() first.")
@lru_cache(maxsize=None)
def bandpass_is_boxcar() -> jnp.ndarray:
    """
    Per-bin flag telling boxcar bins apart from filter-curve bins.

    A bin is flagged True when its response is the uniform boxcar (constant
    weight 1.0) and False when it uses a custom filter throughput curve.
    Convolution code can use this to take a cheaper path for boxcar bins.

    Returns
    -------
    is_boxcar : `~jax.numpy.ndarray`, shape (n_bins,), dtype bool
        True for boxcar bins, False for custom filter bins.

    Raises
    ------
    RuntimeError
        If the flags have not been built yet.
    """
    host_flags = _BAND_BOXCAR_CACHE
    if host_flags is not None:
        return jnp.asarray(host_flags, dtype=jnp.bool_)
    raise RuntimeError("Bandpass boxcar flags not built; call load_bandpass_registry() first.")