Source code for exo_skryer.run_retrieval

"""
run_retrieval.py
================
"""

import os
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")  # suppress XLA/Eigen bitcode warnings
import time
import argparse
from pathlib import Path
from typing import Any, Dict

import numpy as np

from .help_print import format_duration


# Public API: the command-line entry point plus ``format_duration``,
# which is re-exported from ``help_print``.
__all__ = ["main", "format_duration"]


_GPU_PLATFORMS = ("gpu", "cuda")

# NUTS backends: ``sampling.nuts.backend`` -> (relative module, driver name).
# Each NUTS driver returns only a samples dictionary.
_NUTS_BACKENDS = {
    "blackjax": (".sampler_blackjax_MCMC", "run_nuts_blackjax"),
    "numpyro": (".sampler_numpyro_MCMC", "run_nuts_numpyro"),
}

# Nested-sampling engines: ``sampling.engine`` -> (relative module, driver name).
# Each nested-sampling driver returns a ``(samples_dict, evidence_info)`` pair.
_NESTED_ENGINES = {
    "jaxns": (".sampler_jaxns_NS", "run_nested_jaxns"),
    "blackjax_ns": (".sampler_blackjax_NS", "run_nested_blackjax"),
    "ultranest": (".sampler_ultranest_NS", "run_nested_ultranest"),
    "dynesty": (".sampler_dynesty_NS", "run_nested_dynesty"),
    "pymultinest": (".sampler_pymultinest_NS", "run_nested_pymultinest"),
    "nautilus": (".sampler_nautilus_NS", "run_nested_nautilus"),
}


def _append_xla_flags(extra: str) -> None:
    """Append *extra* to ``XLA_FLAGS``, keeping any flags that are already set."""
    existing = os.environ.get("XLA_FLAGS", "").strip()
    os.environ["XLA_FLAGS"] = " ".join(part for part in (existing, extra) if part)


def _configure_platform(cfg: Any, platform: str) -> None:
    """Set JAX/XLA environment variables for the platform chosen in the YAML.

    Called from ``main`` before JAX is imported. CPU settings stay close to JAX
    defaults; GPU-specific allocator/XLA flags are configurable via ``cfg.runtime``.

    Parameters
    ----------
    cfg : Any
        Parsed YAML config; reads ``cfg.runtime.cuda_visible_devices``,
        ``cfg.runtime.tf_gpu_allocator`` (GPU) and ``cfg.runtime.threads`` (CPU).
    platform : str
        Lower-cased platform name: ``"gpu"``/``"cuda"``, ``"metal"``, or anything
        else (treated as CPU).
    """
    if platform in _GPU_PLATFORMS:
        cuda_devices = str(getattr(cfg.runtime, "cuda_visible_devices", ""))
        if cuda_devices:
            os.environ["CUDA_VISIBLE_DEVICES"] = cuda_devices
        os.environ["JAX_PLATFORMS"] = "cuda"
        os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
        os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.75")
        tf_gpu_allocator = getattr(cfg.runtime, "tf_gpu_allocator", None)
        if tf_gpu_allocator:
            os.environ.setdefault("TF_GPU_ALLOCATOR", str(tf_gpu_allocator))
        # Appended rather than assigned so externally supplied XLA_FLAGS are
        # preserved, consistent with the CPU branch below.
        _append_xla_flags(
            "--xla_gpu_enable_latency_hiding_scheduler=true "
            "--xla_gpu_enable_highest_priority_async_stream=true "
            "--xla_gpu_enable_fast_min_max=true "
            "--xla_gpu_deterministic_ops=true"
        )
        print(f"[info] Platform: GPU (CUDA_VISIBLE_DEVICES={cuda_devices})")
        print("[info] XLA GPU: latency hiding, async streams, fast math enabled")
    elif platform == "metal":
        os.environ["JAX_PLATFORMS"] = "metal"
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        print("[info] Platform: Metal (Apple GPU)")
    else:  # cpu
        os.environ["JAX_PLATFORMS"] = "cpu"
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        # NOTE(review): the ``intra_op_parallelism_threads=...`` token has no
        # leading ``--`` and does not look like an XLA flag, so XLA probably
        # ignores it and ``threads`` may not actually cap CPU parallelism --
        # confirm before relying on it.
        _append_xla_flags(
            "--xla_cpu_multi_thread_eigen=true "
            f"intra_op_parallelism_threads={cfg.runtime.threads}"
        )
        print(f"[info] Platform: CPU (threads={cfg.runtime.threads})")


def _init_jax(platform: str) -> None:
    """Import JAX, enable 64-bit mode, and print the active backend and devices.

    On non-GPU platforms JAX still probes every installed CUDA plugin, which
    logs a noisy ERROR when ``cuInit`` returns ``CUDA_ERROR_NO_DEVICE``. The
    ``jax._src.xla_bridge`` logger is silenced for that window only and its
    previous level is restored afterwards, so nothing else is suppressed.

    Parameters
    ----------
    platform : str
        Lower-cased platform name from the config; silencing is skipped for
        ``"gpu"``/``"cuda"``.
    """
    import logging

    xla_bridge_log = logging.getLogger("jax._src.xla_bridge")
    silence = platform not in _GPU_PLATFORMS
    prev_level = xla_bridge_log.level
    if silence:
        xla_bridge_log.setLevel(logging.CRITICAL)
    try:
        from jax import config as jax_config

        jax_config.update("jax_enable_x64", True)
        import jax

        # JAX initialises its backends (and probes plugins) lazily on the first
        # backend/device query, so these calls must sit inside the silenced window.
        backend = jax.default_backend()
        n_local_devices = jax.local_device_count()
        devices = jax.devices()
    finally:
        # Restore even if the import/initialisation raises.
        if silence:
            xla_bridge_log.setLevel(prev_level)

    print(f"[info] JAX backend: {backend}")
    print(f"[info] JAX devices: {n_local_devices} {devices}")


def _load_driver(module_name: str, func_name: str):
    """Import a sibling sampler module lazily and return its driver function.

    Importing on demand (as the original per-engine branches did) means only
    the selected sampler's module is loaded.
    """
    import importlib

    module = importlib.import_module(module_name, package=__package__)
    return getattr(module, func_name)


def _run_sampler(cfg: Any, obs: Dict[str, Any], fm_fnc: Any, exp_dir: Path):
    """Run the sampler selected by ``cfg.sampling.engine``.

    Returns
    -------
    tuple
        ``(samples_dict, evidence_info)``; ``evidence_info`` is ``{}`` for NUTS.

    Raises
    ------
    ValueError
        If the engine, or the NUTS backend, is not recognised.
    """
    engine = cfg.sampling.engine
    if engine == "nuts":
        backend = cfg.sampling.nuts.backend
        if backend not in _NUTS_BACKENDS:
            raise ValueError(f"Unknown backend for NUTS: {backend!r}")
        run_nuts = _load_driver(*_NUTS_BACKENDS[backend])
        return run_nuts(cfg, obs, fm_fnc, exp_dir), {}

    if engine not in _NESTED_ENGINES:
        options = ", ".join(["nuts", *_NESTED_ENGINES])
        raise ValueError(f"Unknown sampling.engine: {engine!r}. Options: {options}")
    run_nested = _load_driver(*_NESTED_ENGINES[engine])
    samples_dict, evidence_info = run_nested(cfg, obs, fm_fnc, exp_dir)
    return samples_dict, evidence_info


def _save_outputs(
    samples_dict: Dict[str, Any], cfg: Any, obs: Dict[str, Any], exp_dir: Path
) -> None:
    """Write the ArviZ posterior and a CSV copy of the observations to *exp_dir*."""
    from .help_io import to_inferencedata, save_inferencedata, save_observed_data_csv

    samples_np = {k: np.asarray(v) for k, v in samples_dict.items()}
    idata = to_inferencedata(samples_np, cfg, include_fixed=False)
    out_nc = save_inferencedata(idata, exp_dir)
    print(f"[info] ArviZ posterior -> {out_nc}")

    # Save a copy of the observational data to a csv in the experiment directory
    save_observed_data_csv(
        exp_dir,
        lam=obs["wl"],
        dlam=obs["dwl"],
        y=obs["y"],
        dy=obs["dy"],
        response_mode=obs.get("response_mode"),
        offset_group=obs.get("offset_group"),
    )
    print(f"[info] Results saved to: {exp_dir.resolve()}")


def main() -> None:
    """Run a retrieval defined by a YAML configuration.

    This function coordinates reading configuration and data, preparing
    opacities and instrument responses, building the forward model, running
    the sampler, and saving outputs to the experiment directory (the folder
    containing the YAML file).

    Usage: ``--config /path/to/config.yaml``

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``sampling.engine`` (or ``sampling.nuts.backend``) is unknown.
    """
    t_start = time.perf_counter()

    # Parse YAML config file from command line: --config /path/to/config.yaml
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="Path to YAML config file")
    args = parser.parse_args()

    # Print process ID (for easy kill)
    print("[info] Process ID: ", os.getpid())

    # Resolve experiment folder (for read & write)
    config_path = Path(args.config).resolve()
    exp_dir = config_path.parent

    # Load YAML parameters - make into a dot namespace
    from .read_yaml import read_yaml

    cfg = read_yaml(config_path)

    # Environment must be configured before JAX is imported in _init_jax.
    platform = str(getattr(cfg.runtime, "platform", "cpu")).lower()
    _configure_platform(cfg, platform)

    # Print main yaml parameters to command line (after setting environment)
    from .help_print import print_cfg

    print_cfg(cfg)

    _init_jax(platform)

    # Load the observational data - returns a dictionary
    from .read_obs import resolve_obs_path, read_obs_data
    from .read_stellar import read_stellar_spectrum

    obs = read_obs_data(resolve_obs_path(cfg), base_dir=exp_dir)

    # Load the opacities (if present in YAML file)
    from .build_opacities import build_opacities, master_wavelength, master_wavelength_cut

    build_opacities(cfg, obs, exp_dir)
    full_grid = np.asarray(master_wavelength(), dtype=float)
    cut_grid = np.asarray(master_wavelength_cut(), dtype=float)
    print(
        f"[info] Master grid: N={full_grid.size}, range=[{full_grid.min():.5f}, {full_grid.max():.5f}]"
    )
    print(
        f"[info] Cut grid: N={cut_grid.size}, range=[{cut_grid.min():.5f}, {cut_grid.max():.5f}]"
    )

    # Read and prepare any response functions and bandpasses for each observational band
    from .registry_bandpass import load_bandpass_registry

    load_bandpass_registry(obs, full_grid, cut_grid)

    # Initialize chemistry backends with experiment-relative paths before the
    # forward model's direct-use fallback runs.
    from .build_chem import load_nasa9_if_needed, init_atmodeller_if_needed

    load_nasa9_if_needed(cfg, exp_dir)
    init_atmodeller_if_needed(cfg, exp_dir)

    # Build the forward model from the YAML options - a function the samplers can use
    from .build_model import build_forward_model

    stellar_flux = read_stellar_spectrum(cfg, cut_grid, bool(cfg.opac.ck), base_dir=exp_dir)
    fm_fnc = build_forward_model(cfg, obs, stellar_flux=stellar_flux)

    t_sampling = time.perf_counter()
    print("[info] Starting Sampling")
    # NOTE(review): evidence_info is not persisted here; confirm the
    # nested-sampling drivers save the evidence themselves.
    samples_dict, evidence_info = _run_sampler(cfg, obs, fm_fnc, exp_dir)
    print("[info] Finished Sampling")
    print("[info] Sampling took:", format_duration(time.perf_counter() - t_sampling))

    _save_outputs(samples_dict, cfg, obs, exp_dir)

    # Print overall runtime
    print("[done] Full model took:", format_duration(time.perf_counter() - t_start))
# Script entry point: run the retrieval when executed directly / via ``python -m``.
if __name__ == "__main__":
    main()