"""
read_stellar.py
================
"""
from __future__ import annotations
from pathlib import Path
from typing import Iterable, Optional
import jax.numpy as jnp
import numpy as np
from scipy.integrate import simpson
__all__ = ["read_stellar_spectrum"]
def _resolve_stellar_path(cfg, base_dir: Optional[Path]) -> Path | None:
data_cfg = getattr(cfg, "data", None)
stellar_path = getattr(data_cfg, "stellar", None) if data_cfg is not None else None
if stellar_path is None:
obs_cfg = getattr(cfg, "obs", None)
if obs_cfg is not None:
stellar_path = getattr(obs_cfg, "stellar", None)
if not stellar_path:
return None
path = Path(stellar_path).expanduser()
if not path.is_absolute():
if base_dir is not None:
path = (Path(base_dir) / path).resolve()
else:
path = (Path.cwd() / path).resolve()
return path
def _load_native_spectrum(path: Path) -> tuple[np.ndarray, np.ndarray]:
data = np.loadtxt(path, comments="#")
if data.ndim == 1 or data.shape[1] < 2:
raise ValueError(f"Stellar spectrum '{path}' must have at least two columns (wl, flux).")
wavelengths = np.asarray(data[:, 0], dtype=float)
flux = np.asarray(data[:, 1], dtype=float)
sort_idx = np.argsort(wavelengths)
return wavelengths[sort_idx], flux[sort_idx]
def _compute_bin_edges(lam_master: np.ndarray) -> np.ndarray:
edges = np.zeros(lam_master.size + 1, dtype=float)
edges[1:-1] = 0.5 * (lam_master[1:] + lam_master[:-1])
spacing_start = lam_master[1] - lam_master[0]
spacing_end = lam_master[-1] - lam_master[-2]
edges[0] = lam_master[0] - 0.5 * spacing_start
edges[-1] = lam_master[-1] + 0.5 * spacing_end
return edges
def _band_average(
wl_native: np.ndarray,
flux_native: np.ndarray,
edges: np.ndarray,
) -> np.ndarray:
# Use log10-space interpolation for stellar flux
log10_flux_native = np.log10(flux_native)
out = np.empty(edges.size - 1, dtype=float)
for i in range(out.size):
left = edges[i]
right = edges[i + 1]
# Include bin edges explicitly so the integral spans the full bin width.
interior_mask = (wl_native > left) & (wl_native < right)
wl_interior = wl_native[interior_mask]
wl_seg = np.concatenate(([left], wl_interior, [right]))
# Interpolate in log10-space, then convert back
log10_fl_seg = np.interp(
wl_seg,
wl_native,
log10_flux_native,
left=log10_flux_native[0],
right=log10_flux_native[-1],
)
fl_seg = 10.0 ** log10_fl_seg
out[i] = simpson(fl_seg, x=wl_seg) / (right - left)
return out
def _native_is_higher_resolution(
wl_native: np.ndarray,
lam_master: np.ndarray,
) -> bool:
if wl_native.size < 2 or lam_master.size < 2:
return False
wl_lo = max(wl_native[0], lam_master[0])
wl_hi = min(wl_native[-1], lam_master[-1])
if wl_hi <= wl_lo:
return False
native_mask = (wl_native >= wl_lo) & (wl_native <= wl_hi)
master_mask = (lam_master >= wl_lo) & (lam_master <= wl_hi)
wl_native_overlap = wl_native[native_mask]
wl_master_overlap = lam_master[master_mask]
if wl_native_overlap.size < 2 or wl_master_overlap.size < 2:
return False
native_dlam = np.diff(wl_native_overlap)
master_dlam = np.diff(wl_master_overlap)
native_dlam = native_dlam[native_dlam > 0.0]
master_dlam = master_dlam[master_dlam > 0.0]
if native_dlam.size == 0 or master_dlam.size == 0:
return False
# Treat native as higher-resolution when its typical spacing is finer.
return np.median(native_dlam) < np.median(master_dlam)
def read_stellar_spectrum(
    cfg,
    lam_master: Iterable[float],
    ck_mode: bool,
    base_dir: Optional[Path] = None,
) -> jnp.ndarray | None:
    """Read the configured stellar spectrum and map it onto the master grid.

    The spectrum path comes from ``cfg.data.stellar``, falling back to
    ``cfg.obs.stellar``. Relative paths are resolved against *base_dir*, or
    against the current working directory when *base_dir* is ``None``.

    Parameters
    ----------
    cfg
        Configuration object exposing the stellar path as described above.
    lam_master
        Master wavelength grid. The log line labels wavelengths as um; the
        native file is presumably in the same units (TODO confirm).
    ck_mode
        When true, always bin-average the native spectrum over the master-grid
        bins. Otherwise, bin-average only if the native grid is finer than the
        master grid, and interpolate in log10 space if it is not.
    base_dir
        Directory used to resolve a relative spectrum path.

    Returns
    -------
    jnp.ndarray or None
        Flux sampled on ``lam_master``. Returns ``None`` when no spectrum is
        configured or the configured file does not exist; the missing-file
        case prints a notice and is skipped.

    Raises
    ------
    ValueError
        If the spectrum file has fewer than two columns.
    """
    path = _resolve_stellar_path(cfg, base_dir)
    if path is None:
        return None
    if not path.exists():
        print(f"[read_stellar] Stellar spectrum not found at {path}; skipping.")
        return None

    wl_native, flux_native = _load_native_spectrum(path)
    lam_master = np.asarray(lam_master, dtype=float)

    # Both ck mode and a finer native grid need bin averaging. ``or``
    # short-circuits, so the resolution check is skipped entirely in ck mode.
    if ck_mode or _native_is_higher_resolution(wl_native, lam_master):
        edges = _compute_bin_edges(lam_master)
        flux_master = _band_average(wl_native, flux_native, edges)
        mode_str = "ck_bin_avg" if ck_mode else "os_bin_avg"
    else:
        # Interpolate in log10 space, which behaves better when flux spans
        # orders of magnitude. End values are held constant outside the
        # native range.
        log10_flux_native = np.log10(flux_native)
        log10_flux_master = np.interp(
            lam_master,
            wl_native,
            log10_flux_native,
            left=log10_flux_native[0],
            right=log10_flux_native[-1],
        )
        flux_master = 10.0 ** log10_flux_master
        mode_str = "os_interp"

    print(
        "[read_stellar] Loaded stellar spectrum: "
        f"path={path}, mode={mode_str}, "
        f"native_N={wl_native.size}, master_N={lam_master.size}, "
        f"wl_native=[{wl_native.min():.5g}, {wl_native.max():.5g}] um, "
        f"wl_master=[{lam_master.min():.5g}, {lam_master.max():.5g}] um, "
        f"flux_master=[{flux_master.min():.5e}, {flux_master.max():.5e}]"
    )
    return jnp.asarray(flux_master)