# Source code for exo_skryer.read_stellar

"""
read_stellar.py
================
"""

from __future__ import annotations

from pathlib import Path
from typing import Iterable, Optional

import jax.numpy as jnp
import numpy as np
from scipy.integrate import simpson


__all__ = ["read_stellar_spectrum"]


def _resolve_stellar_path(cfg, base_dir: Optional[Path]) -> Path | None:
    """Locate the stellar spectrum file named in the configuration.

    The entry is read from ``cfg.data.stellar``; when that attribute is
    missing or ``None`` the lookup falls back to ``cfg.obs.stellar``.

    Parameters
    ----------
    cfg
        Configuration object; it is only inspected through ``getattr``.
    base_dir
        Directory against which a relative entry is resolved. ``None`` means
        the current working directory is used instead.

    Returns
    -------
    Path or None
        ``None`` when no non-empty stellar entry is configured. An absolute
        entry is returned after ``~`` expansion; a relative entry is joined
        onto the base directory and resolved. The file is not checked for
        existence here.
    """
    # getattr on None just yields the default, so the lookups can be chained.
    configured = getattr(getattr(cfg, "data", None), "stellar", None)
    if configured is None:
        configured = getattr(getattr(cfg, "obs", None), "stellar", None)
    if not configured:
        return None

    candidate = Path(configured).expanduser()
    if candidate.is_absolute():
        return candidate

    anchor = Path.cwd() if base_dir is None else Path(base_dir)
    return (anchor / candidate).resolve()


def _load_native_spectrum(path: Path) -> tuple[np.ndarray, np.ndarray]:
    """Load a whitespace-delimited text spectrum, sorted by its first column.

    Parameters
    ----------
    path
        Text file readable by ``numpy.loadtxt``; lines starting with ``#`` are
        treated as comments. The first column is taken as wavelength and the
        second as flux; any further columns are ignored.

    Returns
    -------
    tuple of numpy.ndarray
        ``(wavelengths, flux)`` as float arrays, reordered so that the
        wavelengths are ascending.

    Raises
    ------
    ValueError
        If the file provides fewer than two columns.
    """
    # ndmin=2 keeps a single-row file as shape (1, ncols) instead of letting it
    # collapse to 1-D, so the check below really tests the column count and a
    # valid one-row "wl flux" file is no longer rejected.
    data = np.loadtxt(path, comments="#", ndmin=2)
    if data.shape[1] < 2:
        raise ValueError(f"Stellar spectrum '{path}' must have at least two columns (wl, flux).")
    wavelengths = np.asarray(data[:, 0], dtype=float)
    flux = np.asarray(data[:, 1], dtype=float)
    sort_idx = np.argsort(wavelengths)
    return wavelengths[sort_idx], flux[sort_idx]


def _compute_bin_edges(lam_master: np.ndarray) -> np.ndarray:
    """Build bin edges around a grid of bin-centre wavelengths.

    Interior edges are the midpoints between neighbouring centres. The two
    outer edges are placed half of the adjacent spacing beyond the first and
    last centre.

    Parameters
    ----------
    lam_master
        One-dimensional grid of bin centres (any array-like is accepted).

    Returns
    -------
    numpy.ndarray
        Float array with ``len(lam_master) + 1`` edges.

    Raises
    ------
    ValueError
        If the grid is not one-dimensional or has fewer than two points; the
        outer edges need a neighbouring spacing to be defined.
    """
    centres = np.asarray(lam_master, dtype=float)
    if centres.ndim != 1 or centres.size < 2:
        raise ValueError(
            "lam_master must be a 1-D grid with at least two points to define bin edges."
        )

    midpoints = 0.5 * (centres[1:] + centres[:-1])
    first_edge = centres[0] - 0.5 * (centres[1] - centres[0])
    last_edge = centres[-1] + 0.5 * (centres[-1] - centres[-2])
    return np.concatenate(([first_edge], midpoints, [last_edge]))


def _band_average(
    wl_native: np.ndarray,
    flux_native: np.ndarray,
    edges: np.ndarray,
) -> np.ndarray:
    """Average a native-resolution spectrum over each wavelength bin.

    For every bin ``[edges[i], edges[i + 1]]`` the flux is sampled at the two
    bin edges plus every native point strictly inside the bin, integrated with
    Simpson's rule, and divided by the bin width. Sampling interpolates
    ``log10(flux)`` linearly in wavelength; outside the native range the first
    or last native value is held constant.

    Parameters
    ----------
    wl_native
        Native wavelengths, ascending (required by ``np.interp`` and by the
        binary search below).
    flux_native
        Native flux values, same length as ``wl_native``.
    edges
        Bin edges; ``edges.size - 1`` bins are produced.

    Returns
    -------
    numpy.ndarray
        Float array holding one band-averaged flux per bin.
    """
    # NOTE: the interpolation is done in log10-space, so zero or negative
    # native flux values turn into -inf/NaN here and can contaminate the bins
    # that touch them.
    log10_flux_native = np.log10(flux_native)

    # Because wl_native is ascending, the points strictly inside a bin form a
    # contiguous slice. Locating it by binary search avoids scanning the whole
    # native array with a boolean mask for every bin.
    first_inside = np.searchsorted(wl_native, edges[:-1], side="right")
    end_inside = np.searchsorted(wl_native, edges[1:], side="left")

    out = np.empty(edges.size - 1, dtype=float)
    for i in range(out.size):
        left = edges[i]
        right = edges[i + 1]
        # Include bin edges explicitly so the integral spans the full bin width.
        wl_interior = wl_native[first_inside[i]:end_inside[i]]
        wl_seg = np.concatenate(([left], wl_interior, [right]))
        # Interpolate in log10-space, then convert back
        log10_fl_seg = np.interp(
            wl_seg,
            wl_native,
            log10_flux_native,
            left=log10_flux_native[0],
            right=log10_flux_native[-1],
        )
        fl_seg = 10.0 ** log10_fl_seg
        out[i] = simpson(fl_seg, x=wl_seg) / (right - left)
    return out


def _native_is_higher_resolution(
    wl_native: np.ndarray,
    lam_master: np.ndarray,
) -> bool:
    """Tell whether the native grid is sampled more finely than the master grid.

    Only the wavelength range shared by both grids is compared. Within it, the
    median positive spacing of each grid is computed, and the native grid
    counts as higher resolution when its median spacing is strictly smaller.

    Parameters
    ----------
    wl_native
        Native wavelength grid; its first and last elements are used as the
        range limits, so it is expected to be ascending.
    lam_master
        Master wavelength grid, treated the same way.

    Returns
    -------
    bool
        ``True`` when the native grid is finer in the overlap. ``False`` when
        it is not, or when the comparison cannot be made (fewer than two
        points in a grid or in the overlap, no overlap, or no positive
        spacing).
    """
    if wl_native.size < 2 or lam_master.size < 2:
        return False

    wl_lo = max(wl_native[0], lam_master[0])
    wl_hi = min(wl_native[-1], lam_master[-1])
    if wl_hi <= wl_lo:
        return False

    native_mask = (wl_native >= wl_lo) & (wl_native <= wl_hi)
    master_mask = (lam_master >= wl_lo) & (lam_master <= wl_hi)
    wl_native_overlap = wl_native[native_mask]
    wl_master_overlap = lam_master[master_mask]
    if wl_native_overlap.size < 2 or wl_master_overlap.size < 2:
        return False

    # Repeated wavelengths give zero spacing; drop them so they cannot drag
    # the median down.
    native_dlam = np.diff(wl_native_overlap)
    master_dlam = np.diff(wl_master_overlap)
    native_dlam = native_dlam[native_dlam > 0.0]
    master_dlam = master_dlam[master_dlam > 0.0]
    if native_dlam.size == 0 or master_dlam.size == 0:
        return False

    # Treat native as higher-resolution when its typical spacing is finer.
    # bool() turns numpy.bool_ into the plain bool that the signature promises
    # and that the early exits above already return.
    return bool(np.median(native_dlam) < np.median(master_dlam))


def read_stellar_spectrum(
    cfg,
    lam_master: Iterable[float],
    ck_mode: bool,
    base_dir: Optional[Path] = None,
) -> jnp.ndarray | None:
    """Load the configured stellar spectrum and map it onto ``lam_master``.

    Parameters
    ----------
    cfg
        Configuration object that is searched for the stellar file entry.
    lam_master
        Master wavelength grid; converted to a float NumPy array.
    ck_mode
        When true the spectrum is always band-averaged over the master bins.
        Otherwise it is band-averaged only if the native grid is finer than
        the master grid, and interpolated in log10-flux if not.
    base_dir
        Base directory for a relative stellar path.

    Returns
    -------
    jax.numpy.ndarray or None
        Flux on the master grid, or ``None`` when no stellar file is
        configured or the configured file does not exist (a message is
        printed in the latter case).
    """
    path = _resolve_stellar_path(cfg, base_dir)
    if path is None:
        return None
    if not path.exists():
        print(f"[read_stellar] Stellar spectrum not found at {path}; skipping.")
        return None

    wl_native, flux_native = _load_native_spectrum(path)
    lam_master = np.asarray(lam_master, dtype=float)

    # `or` short-circuits, so the resolution test only runs outside ck mode.
    if ck_mode or _native_is_higher_resolution(wl_native, lam_master):
        mode_str = "ck_bin_avg" if ck_mode else "os_bin_avg"
        flux_master = _band_average(
            wl_native,
            flux_native,
            _compute_bin_edges(lam_master),
        )
    else:
        mode_str = "os_interp"
        # Interpolating log10(flux) behaves better than linear flux when the
        # spectrum spans several orders of magnitude.
        log10_flux_native = np.log10(flux_native)
        flux_master = 10.0 ** np.interp(
            lam_master,
            wl_native,
            log10_flux_native,
            left=log10_flux_native[0],
            right=log10_flux_native[-1],
        )

    print(
        "[read_stellar] Loaded stellar spectrum: "
        f"path={path}, mode={mode_str}, "
        f"native_N={wl_native.size}, master_N={lam_master.size}, "
        f"wl_native=[{wl_native.min():.5g}, {wl_native.max():.5g}] um, "
        f"wl_master=[{lam_master.min():.5g}, {lam_master.max():.5g}] um, "
        f"flux_master=[{flux_master.min():.5e}, {flux_master.max():.5e}]"
    )
    return jnp.asarray(flux_master)