Source code for exo_skryer.build_chem

"""
build_chem.py
=============
"""

from __future__ import annotations

import importlib.util
from pathlib import Path
from typing import Iterable, Any

from .limb_asymmetry import is_internal_parameter_name, split_limb_tag

__all__ = [
    'infer_trace_species',
    'infer_log10_vmr_keys',
    'infer_clr_keys',
    'validate_log10_vmr_params',
    'validate_clr_params',
    'clr_samples_to_vmr',
    'prepare_chemistry_kernel',
    'load_nasa9_if_needed',
    'init_fastchem_grid_if_needed',
    'init_element_potentials_if_needed',
    'init_atmodeller_if_needed',
    'init_quench_approx_if_needed',
]


def _extract_species_list(block) -> list[str]:
    if not block:
        return []
    if isinstance(block, bool):
        return []
    names: list[str] = []
    try:
        iterator = iter(block)
    except TypeError:
        iterator = iter((block,))
    for item in iterator:
        name = getattr(item, "species", item)
        names.append(str(name).strip())
    return names


def _append_unique(seq: list[str], name: str) -> None:
    name = str(name).strip()
    if not name:
        return
    if name not in seq:
        seq.append(name)


[docs] def infer_trace_species( cfg, line_opac_scheme_str: str, ray_opac_scheme_str: str, cia_opac_scheme_str: str, special_opac_scheme_str: str, ) -> tuple[str, ...]: required: list[str] = [] def add_many(names: Iterable[str]) -> None: for n in names: _append_unique(required, n) if line_opac_scheme_str.lower() == "os": add_many(_extract_species_list(getattr(cfg.opac, "line", None))) elif line_opac_scheme_str.lower() == "ck": ck_mode = getattr(cfg.opac, "ck", None) if isinstance(ck_mode, bool): ck_block = getattr(cfg.opac, "line", None) else: ck_block = ck_mode add_many(_extract_species_list(ck_block)) if ray_opac_scheme_str.lower() in ("os", "ck"): add_many(_extract_species_list(getattr(cfg.opac, "ray", None))) if cia_opac_scheme_str.lower() in ("os", "ck"): for cia_name in _extract_species_list(getattr(cfg.opac, "cia", None)): if cia_name == "H-": # H- is treated as special opacity (bound-free/free-free), not a CIA pair. continue parts = cia_name.split("-") if len(parts) == 2: # CIA pairs that include atomic hydrogen should not force H to be a # retrieved trace VMR; in constant-VMR modes H is derived from the # filler via log_10_H_over_H2. for p in (parts[0], parts[1]): if p in ("H2", "He", "H"): continue _append_unique(required, p) if special_opac_scheme_str.lower() not in ("none", "off", "false", "0"): opac_cfg = getattr(cfg, "opac", None) special_cfg = getattr(opac_cfg, "special", None) if opac_cfg is not None else None hm_enabled = False hm_bf = True hm_ff = False # New config: cfg.opac.special includes species='H-' with optional bf/ff toggles if special_cfg not in (None, "None", "none", False): hm_enabled = True if not isinstance(special_cfg, bool): try: iterator = iter(special_cfg) except TypeError: iterator = iter((special_cfg,)) for item in iterator: name = getattr(item, "species", item) if str(name).strip() != "H-": continue hm_bf = bool(getattr(item, "bf", hm_bf)) hm_ff = bool(getattr(item, "ff", hm_ff)) hm_enabled = True break # Back-compat: cfg.opac.cia includes H- (treated as bf only unless ff flag set) if not hm_enabled: cia_cfg = getattr(opac_cfg, "cia", None) if opac_cfg is not None else None if cia_cfg not in (None, "None", "none", False): try: iterator = iter(cia_cfg) except TypeError: iterator = iter((cia_cfg,)) for item in iterator: name = getattr(item, "species", item) if str(name).strip() != "H-": continue hm_enabled = True hm_bf = True hm_ff = bool(getattr(item, "ff", hm_ff)) break if hm_enabled and hm_bf: _append_unique(required, "H-") # For H- free-free: electrons and atomic-H are handled via separate # retrieved proxies (ne/n_tot and H/H2), not via the VMR machinery. # H2/He are always filler species, and atomic H is derived from the filler # when needed via the dedicated log_10_H_over_H2 parameter. trace_species = tuple(s for s in required if s not in ("H2", "He", "H")) return trace_species
def infer_active_opacity_species(cfg) -> tuple[str, ...]: """Infer active chemistry species directly required by configured opacities.""" required: list[str] = [] def add(name: str) -> None: _append_unique(required, name) def add_many(names: Iterable[str]) -> None: for n in names: add(n) line_mode = str(getattr(getattr(cfg, "physics", None), "opac_line", "none")).lower() if line_mode in ("os", "ck"): if line_mode == "ck": ck_cfg = getattr(getattr(cfg, "opac", None), "ck", None) block = getattr(cfg.opac, "line", None) if isinstance(ck_cfg, bool) else ck_cfg add_many(_extract_species_list(block)) else: add_many(_extract_species_list(getattr(cfg.opac, "line", None))) ray_mode = str(getattr(getattr(cfg, "physics", None), "opac_ray", "none")).lower() if ray_mode in ("os", "ck"): add_many(_extract_species_list(getattr(cfg.opac, "ray", None))) cia_mode = str(getattr(getattr(cfg, "physics", None), "opac_cia", "none")).lower() if cia_mode in ("os", "ck"): for cia_name in _extract_species_list(getattr(cfg.opac, "cia", None)): if cia_name == "H-": continue parts = cia_name.split("-") if len(parts) == 2: add(parts[0]) add(parts[1]) special_mode = str(getattr(getattr(cfg, "physics", None), "opac_special", "none")).lower() if special_mode not in ("none", "off", "false", "0"): special_cfg = getattr(getattr(cfg, "opac", None), "special", None) hm_enabled = False hm_bf = True hm_ff = False if special_cfg not in (None, "None", "none", False): try: iterator = iter(special_cfg) if not isinstance(special_cfg, bool) else iter(()) except TypeError: iterator = iter((special_cfg,)) for item in iterator: spec = str(getattr(item, "species", item)) if spec != "H-": continue hm_enabled = True hm_bf = bool(getattr(item, "bf", hm_bf)) hm_ff = bool(getattr(item, "ff", hm_ff)) break if hm_enabled and hm_bf: add("H-") if hm_enabled and hm_ff: add("H") add("e-") return tuple(required) def _special_hminus_requirements(cfg) -> tuple[bool, bool]: """Return (need_bf, need_ff) from physics/opac special configuration.""" special_mode = str(getattr(getattr(cfg, "physics", None), "opac_special", "none")).lower() if special_mode in ("none", "off", "false", "0"): return False, False need_bf = False need_ff = False special_cfg = getattr(getattr(cfg, "opac", None), "special", None) if special_cfg not in (None, "None", "none", False): try: iterator = iter(special_cfg) if not isinstance(special_cfg, bool) else iter(()) except TypeError: iterator = iter((special_cfg,)) for item in iterator: spec = str(getattr(item, "species", item)) if spec != "H-": continue need_bf = bool(getattr(item, "bf", True)) need_ff = bool(getattr(item, "ff", False)) return need_bf, need_ff # Backward-compatible behavior: if special is enabled but no explicit H- row, # treat bf as enabled by default. return True, False
[docs] def infer_log10_vmr_keys(trace_species: tuple[str, ...]) -> tuple[str, ...]: return tuple(f"log_10_f_{s}" for s in trace_species)
[docs] def infer_clr_keys(trace_species: tuple[str, ...]) -> tuple[str, ...]: """Return CLR parameter keys for each trace species. Parameters ---------- trace_species : tuple of str Ordered tuple of trace species names. Returns ------- tuple of str Parameter keys with ``clr_`` prefix, e.g. ``('clr_H2O', 'clr_CO')``. """ return tuple(f"clr_{s}" for s in trace_species)
def _cfg_param_base_names(cfg) -> set[str]: """Return config parameter names with any transit_1_5d limb tag stripped.""" base_names: set[str] = set() for p in getattr(cfg, "params", []): name = str(getattr(p, "name", "")) if not name or is_internal_parameter_name(name): continue base, _tag = split_limb_tag(name) base_names.add(base) return base_names
[docs] def validate_log10_vmr_params(cfg, trace_species: tuple[str, ...]) -> None: cfg_param_names = _cfg_param_base_names(cfg) missing = [s for s in trace_species if f"log_10_f_{s}" not in cfg_param_names] if missing: joined = ", ".join(missing) raise ValueError( f"Missing required VMR parameters for: {joined}. " f"Add `log_10_f_<species>` entries to cfg.params." )
[docs] def validate_clr_params(cfg, trace_species: tuple[str, ...]) -> bool: """Validate that abundance parameters are present for CLR mode. This function accepts either CLR parameters (``clr_<species>``) or traditional log10 VMR parameters (``log_10_f_<species>``). When log10 VMR parameters are used, they will be converted to CLR coordinates internally via softmax, which acts as a soft constraint ensuring valid atmospheric composition. Parameters ---------- cfg : config object Configuration object containing params list. trace_species : tuple of str Ordered tuple of trace species names. Returns ------- bool True if using log_10_f_* parameters, False if using clr_* parameters. Raises ------ ValueError If neither parameter style is found, or if mixing both styles. """ cfg_param_names = _cfg_param_base_names(cfg) # Check which parameter style is being used has_clr = any(f"clr_{s}" in cfg_param_names for s in trace_species) has_log10 = any(f"log_10_f_{s}" in cfg_param_names for s in trace_species) if has_clr and has_log10: raise ValueError( "Cannot mix clr_* and log_10_f_* parameters. " "Use either CLR parameters (clr_H2O, clr_CO, ...) " "or log10 VMR parameters (log_10_f_H2O, log_10_f_CO, ...)" ) if has_clr: # Validate all CLR parameters are present missing = [s for s in trace_species if f"clr_{s}" not in cfg_param_names] if missing: joined = ", ".join(missing) raise ValueError( f"Missing required CLR parameters for: {joined}. " f"Add `clr_<species>` entries to cfg.params." ) return False # Using native CLR parameters if has_log10: # Validate all log10 VMR parameters are present missing = [s for s in trace_species if f"log_10_f_{s}" not in cfg_param_names] if missing: joined = ", ".join(missing) raise ValueError( f"Missing required VMR parameters for: {joined}. " f"Add `log_10_f_<species>` entries to cfg.params." ) return True # Using log10 VMR parameters, will convert internally # Neither style found raise ValueError( f"No abundance parameters found for species: {', '.join(trace_species)}. " f"Add either `clr_<species>` or `log_10_f_<species>` entries to cfg.params." )
[docs] def clr_samples_to_vmr( samples_dict: dict[str, "np.ndarray"], species: tuple[str, ...] | list[str], ) -> dict[str, "np.ndarray"]: """Convert CLR posterior samples to physical VMR (and log10 VMR) columns. Applies the same softmax inverse as ``constant_vmr_clr`` but operates on NumPy arrays of posterior samples rather than JAX scalars. Parameters ---------- samples_dict : dict[str, np.ndarray] Posterior samples keyed by parameter name. Must contain ``clr_<species>`` for every species in *species*. species : tuple or list of str Ordered trace species names (e.g. ``('H2O', 'CO', 'CO2')``). Returns ------- derived : dict[str, np.ndarray] New columns ready to merge into the samples dictionary: - ``log_10_f_<species>`` : log10(VMR) for each trace species - ``f_<species>`` : linear VMR for each trace species - ``f_H2_He`` : combined H2+He filler fraction """ import numpy as np from scipy.special import logsumexp as _logsumexp species = tuple(species) n_samples = len(samples_dict[f"clr_{species[0]}"]) # (N_samples, N_species) matrix of CLR values z = np.column_stack([samples_dict[f"clr_{s}"] for s in species]) # Prepend z_filler = 0 column → (N_samples, N_species + 1) z_all = np.concatenate([np.zeros((n_samples, 1)), z], axis=1) # Softmax: VMR_i = exp(z_i) / sum(exp(z_j)) log_denom = _logsumexp(z_all, axis=1, keepdims=True) vmr_all = np.exp(z_all - log_denom) # First column is filler, rest are trace species (same order as input) derived: dict[str, np.ndarray] = {} derived["f_H2_He"] = vmr_all[:, 0] for i, s in enumerate(species): vmr_i = vmr_all[:, i + 1] derived[f"f_{s}"] = vmr_i derived[f"log_10_f_{s}"] = np.log10(np.maximum(vmr_i, 1e-300)) return derived
[docs] def prepare_chemistry_kernel(cfg, chemistry_kernel, opacity_schemes: dict): """Prepare and validate chemistry kernel with inferred species. This function infers the required trace species from the opacity configuration, validates that necessary parameters are present, and returns an optimized chemistry kernel ready for use in the forward model. Parameters ---------- cfg : config object Configuration object containing params and opac settings. chemistry_kernel : callable The base chemistry kernel function from vert_chem (e.g., constant_vmr, CE_fastchem_jax, CE_rate_jax). opacity_schemes : dict Dictionary containing opacity scheme strings with keys: 'line_opac', 'ray_opac', 'cia_opac', 'special_opac'. Returns ------- chemistry_kernel : callable The prepared chemistry kernel, potentially optimized for JIT compilation. trace_species : tuple of str Tuple of trace species names inferred from the opacity configuration. """ from .vert_chem import constant_vmr, constant_vmr_clr # Infer required species from opacity configuration trace_species = infer_trace_species( cfg, line_opac_scheme_str=opacity_schemes['line_opac'], ray_opac_scheme_str=opacity_schemes['ray_opac'], cia_opac_scheme_str=opacity_schemes['cia_opac'], special_opac_scheme_str=opacity_schemes['special_opac'], ) # If configured opacities require atomic H (CIA H2-H / He-H pairs, or H- ff), # constant-VMR chemistry derives H from the H2+He filler using log_10_H_over_H2. need_atomic_h = False opac_cfg = getattr(cfg, "opac", None) cia_names = _extract_species_list(getattr(opac_cfg, "cia", None) if opac_cfg is not None else None) for name in cia_names: parts = str(name).strip().split("-") if len(parts) == 2 and "H" in parts and "H-" not in parts: need_atomic_h = True break if not need_atomic_h: special_cfg = getattr(opac_cfg, "special", None) if opac_cfg is not None else None if special_cfg not in (None, "None", "none", False): try: iterator = iter(special_cfg) if not isinstance(special_cfg, bool) else iter(()) except TypeError: iterator = iter((special_cfg,)) for item in iterator: spec = getattr(item, "species", item) if str(spec).strip() != "H-": continue if bool(getattr(item, "ff", False)): need_atomic_h = True break if need_atomic_h: from .vert_chem import constant_vmr, constant_vmr_clr if chemistry_kernel in (constant_vmr, constant_vmr_clr): cfg_param_names = _cfg_param_base_names(cfg) if "log_10_H_over_H2" not in cfg_param_names: raise ValueError( "Atomic H is required by the configured CIA/special opacities, but parameter " "'log_10_H_over_H2' is missing. Add it to cfg.params to derive H from the " "H2+He filler (constant_vmr/constant_vmr_clr)." ) # For constant VMR: validate and build optimized kernel if chemistry_kernel is constant_vmr: validate_log10_vmr_params(cfg, trace_species) chemistry_kernel = constant_vmr(trace_species) elif chemistry_kernel is constant_vmr_clr: use_log10_vmr = validate_clr_params(cfg, trace_species) chemistry_kernel = constant_vmr_clr(trace_species, use_log10_vmr=use_log10_vmr) return chemistry_kernel, trace_species
[docs] def load_nasa9_if_needed(cfg: Any, exp_dir: Path) -> None: """Load NASA-9 thermo coefficients if RateJAX chemistry requires it. This function checks if the configured chemistry scheme requires Gibbs free energy data (RateJAX chemical equilibrium modes) and loads the NASA-9 tables if needed. If data is already loaded or not required, it returns immediately. Parameters ---------- cfg : config object Parsed YAML configuration object with `cfg.physics.vert_chem` attribute. exp_dir : `~pathlib.Path` Experiment directory used to resolve relative paths to NASA-9 data. """ phys = getattr(cfg, "physics", None) if phys is None: return vert_chem_raw = getattr(phys, "vert_chem", None) if vert_chem_raw is None: return vert_chem_name = str(vert_chem_raw).lower() ep_names = ("easychem_jax", "easychem") fc_grid_names = ( "fastchem_grid_jax", "ce_fastchem_grid", "fastchem_ce_grid", "ce", "chemical_equilibrium", "ce_fastchem_jax", "fastchem_jax", ) needs_nasa9 = ( "rate_ce", "rate_jax", "ce_rate_jax", *ep_names, ) if vert_chem_name in ("quench_approx", "quench"): init_quench_approx_if_needed(cfg, exp_dir) return if vert_chem_name in fc_grid_names: init_fastchem_grid_if_needed(cfg, exp_dir) return if vert_chem_name not in needs_nasa9: return from .rate_jax import is_nasa9_cache_loaded, load_nasa9_cache if is_nasa9_cache_loaded(): print("[info] NASA-9 cache already loaded") else: data_cfg = getattr(cfg, "data", None) nasa9_rel_path = getattr(data_cfg, "nasa9", None) if data_cfg is not None else None if nasa9_rel_path is None: raise ValueError( "NASA-9 data path not found in config. Please add 'nasa9: path/to/NASA9' " "under 'data:' section in YAML config." ) base_dir = Path.cwd() if exp_dir is None else Path(exp_dir) nasa9_path = ( str(base_dir / nasa9_rel_path) if not Path(nasa9_rel_path).is_absolute() else nasa9_rel_path ) print(f"[info] Loading NASA-9 thermo tables from {nasa9_path}") thermo = load_nasa9_cache(nasa9_path) print(f"[info] NASA-9 cache loaded: {len(thermo.data)} species") # EP backend needs its own model cache as well; initialize it here so callers # that only invoke `load_nasa9_if_needed` (e.g. bestfit scripts) still work. if vert_chem_name in ep_names: init_element_potentials_if_needed(cfg, exp_dir) return
[docs] def init_fastchem_grid_if_needed(cfg: Any, exp_dir: Path) -> None: """Initialise and cache the FastChem 5D grid chemistry backend if required.""" phys = getattr(cfg, "physics", None) if phys is None: return vert_chem_name = str(getattr(phys, "vert_chem", "") or "").lower() aliases = ( "fastchem_grid_jax", "ce_fastchem_grid", "fastchem_ce_grid", "ce", "chemical_equilibrium", "ce_fastchem_jax", "fastchem_jax", "quench_approx", "quench", ) if vert_chem_name not in aliases: return from .vert_chem import ( is_fastchem_grid_cache_loaded, load_fastchem_grid_cache, get_fastchem_grid_cache_info, ) if is_fastchem_grid_cache_loaded(): print("[info] FastChem-grid cache already loaded") return fc_cfg = getattr(cfg, "fastchem_grid_jax", None) if fc_cfg is None: raise ValueError( "fastchem_grid_jax config block is required when physics.vert_chem " "uses the FastChem-grid backend." ) grid_path_raw = getattr(fc_cfg, "grid_path", None) if not grid_path_raw: raise ValueError("fastchem_grid_jax.grid_path is required and must point to a 5D NPZ grid.") grid_path_obj = Path(str(grid_path_raw)) if grid_path_obj.is_absolute(): grid_path = grid_path_obj else: base_dir = Path.cwd() if exp_dir is None else Path(exp_dir) candidate_exp = (base_dir / grid_path_obj).resolve() candidate_repo = (Path.cwd() / grid_path_obj).resolve() if candidate_exp.exists(): grid_path = candidate_exp else: grid_path = candidate_repo if not grid_path.exists(): raise FileNotFoundError(f"FastChem grid file not found: {grid_path}") bounds_cfg = getattr(fc_cfg, "bounds", None) bounds_mode = str(getattr(bounds_cfg, "mode", "clip")).lower() if bounds_cfg is not None else "clip" if bounds_mode != "clip": raise ValueError("fastchem_grid_jax.bounds.mode currently supports only: clip") param_names = _cfg_param_base_names(cfg) for required in ("M_to_H", "C_to_O"): if required not in param_names: raise ValueError( f"fastchem_grid_jax requires parameter '{required}' in cfg.params " "or an explicit transit_1_5d variant with _joint, _east, or _west suffix." ) species_out = infer_active_opacity_species(cfg) if not species_out: raise ValueError( "Could not infer any active opacity species for fastchem_grid_jax. " "Check opac.line/opac.ray/opac.cia/opac.special configuration." ) solver_cfg = getattr(fc_cfg, "solver", None) solver_kwargs: dict[str, Any] = {} if solver_cfg is not None: mode = getattr(solver_cfg, "mode", None) if mode is not None: solver_kwargs["mode"] = str(mode) if "mode" not in solver_kwargs: solver_kwargs["mode"] = "vmap" species_map_override: dict[str, str] = {} map_cfg = getattr(fc_cfg, "species_map", None) if map_cfg is not None: if isinstance(map_cfg, dict): items = map_cfg.items() else: items = vars(map_cfg).items() for key, val in items: species_map_override[str(key)] = str(val) print("[info] Initializing CE scheme: fastchem_grid_jax") print(f"[info] Grid path: {grid_path}") print(f"[info] Output species source: active opacity species ({len(species_out)})") print(f"[info] Solver: mode={solver_kwargs['mode']}") print(f"[info] Bounds: mode=clip") load_fastchem_grid_cache( grid_path=str(grid_path), species_out=list(species_out), solver_cfg=solver_kwargs, species_map_override=species_map_override if species_map_override else None, ) info = get_fastchem_grid_cache_info() if info: shape = info.get("shape", ()) print( "[info] Grid axes: " f"T={shape[0]}, P={shape[1]}, M/H={shape[2]}, C/O={shape[3]}, species={shape[4]}" ) print( "[info] Ranges: " f"T=[{info['T_range'][0]:.1f}, {info['T_range'][1]:.1f}] K, " f"P=[{info['P_range'][0]:.3e}, {info['P_range'][1]:.3e}] bar, " f"M/H=[{info['MH_range'][0]:.2f}, {info['MH_range'][1]:.2f}] dex, " f"C/O=[{info['CO_range'][0]:.3f}, {info['CO_range'][1]:.3f}]" ) print(f"[info] Interp axes: {'log10(T,P,C/O)+dex(M/H)' if info.get('use_log_axes', False) else 'linear'}") mapped_species = tuple(info.get("species_out", ())) print(f"[info] Mapped species: {len(mapped_species)}") if mapped_species: print(f"[info] Species list: {list(mapped_species)}") if info.get("unmapped_species"): print(f"[warn] Unmapped species skipped: {list(info['unmapped_species'])}") need_bf, need_ff = _special_hminus_requirements(cfg) required_special: list[str] = [] if need_bf: required_special.append("H-") if need_ff: required_special.extend(("H", "e-")) if required_special: missing_special = [sp for sp in required_special if sp not in mapped_species] if missing_special: raise ValueError( "H- special opacity is enabled, but required species are not mapped " "from the FastChem CE grid. " f"Missing={missing_special}, mapped={list(mapped_species)}. " "Fix via fastchem_grid_jax.species_map and/or grid contents." ) print("[info] FastChem-grid cache loaded")
[docs] def init_element_potentials_if_needed(cfg: Any, exp_dir: Path) -> None: """Initialise and cache the element-potentials chemistry backend if required.""" phys = getattr(cfg, "physics", None) if phys is None: return vert_chem_name = str(getattr(phys, "vert_chem", "") or "").lower() if vert_chem_name not in ("easychem_jax", "easychem"): return from .vert_chem import is_element_potentials_cache_loaded, load_element_potentials_cache if is_element_potentials_cache_loaded(): print("[info] Element-potentials cache already loaded") return ep_cfg = getattr(cfg, "easychem_jax", None) if ep_cfg is None: raise ValueError( "easychem_jax config block is required when physics.vert_chem is " "'easychem_jax'." ) raw_species = list(getattr(ep_cfg, "species", None) or []) species_list = [] for sp in raw_species: if isinstance(sp, bool): # YAML 1.1 parses unquoted `NO`/`Yes` as booleans; recover the common NO case. species_list.append("NO" if sp is False else "YES") else: species_list.append(str(sp)) if not species_list: raise ValueError( "easychem_jax.species must be a non-empty list in the YAML config." ) elements_val = getattr(ep_cfg, "elements", None) elements = list(elements_val) if elements_val else None solver_cfg = getattr(ep_cfg, "solver", None) solver_kwargs = {} if solver_cfg is not None: for key in ("mode", "max_steps", "tol", "throw", "relax_limit", "prefer_chord"): val = getattr(solver_cfg, key, None) if val is not None: solver_kwargs[key] = val data_cfg = getattr(cfg, "data", None) nasa9_rel_path = getattr(data_cfg, "nasa9", None) if data_cfg is not None else None if nasa9_rel_path is None: raise ValueError( "NASA-9 data path not found in config. Please add 'nasa9: path/to/NASA9' " "under 'data:' section in YAML config." ) base_dir = Path.cwd() if exp_dir is None else Path(exp_dir) nasa9_path = ( str(base_dir / nasa9_rel_path) if not Path(nasa9_rel_path).is_absolute() else nasa9_rel_path ) p0_bar = float(getattr(ep_cfg, "p0_bar", 1.0)) e_ref = str(getattr(ep_cfg, "e_ref", "H")) nlay = int(getattr(phys, "nlay", 99)) print("[info] Initializing CE scheme: easychem_jax") print(f"[info] NASA9 path: {nasa9_path}") print(f"[info] Species count: {len(species_list)}") print(f"[info] Elements: {elements if elements is not None else 'auto (inferred)'}") print(f"[info] e_ref: {e_ref}, P0_bar: {p0_bar}") print( "[info] Solver: " f"mode={solver_kwargs.get('mode', 'scan')}, " f"max_steps={solver_kwargs.get('max_steps', 64)}, " f"tol={solver_kwargs.get('tol', 1.0e-11)}, " f"throw={solver_kwargs.get('throw', False)}, " f"relax_limit={solver_kwargs.get('relax_limit', 0.75)}" ) load_element_potentials_cache( species_list=species_list, elements=elements, nlay=nlay, solver_kwargs=solver_kwargs, nasa9_dir=nasa9_path, p0_bar=p0_bar, e_ref=e_ref, ) print("[info] Element-potentials cache loaded")
[docs] def init_atmodeller_if_needed(cfg: Any, exp_dir: Path) -> None: """Initialise the global atmodeller :class:`EquilibriumModel` if required. Reads ``atmodeller.species_network`` from the YAML config and builds the cached :class:`EquilibriumModel` instance. If the cache is already loaded, or the active chemistry scheme is not ``'atmodeller'``, returns immediately. Parameters ---------- cfg : config object Parsed YAML configuration object with ``cfg.physics.vert_chem`` and ``cfg.atmodeller.species_network`` attributes. exp_dir : `~pathlib.Path` Experiment directory (unused; kept for API symmetry with :func:`load_nasa9_if_needed`). """ phys = getattr(cfg, "physics", None) if phys is None: return vert_chem_name = str(getattr(phys, "vert_chem", "") or "").lower() if vert_chem_name != "atmodeller": return if importlib.util.find_spec("atmodeller") is None: raise ImportError( "physics.vert_chem is set to 'atmodeller', but the optional 'atmodeller' " "package is not installed." ) atm_cfg = getattr(cfg, "atmodeller", None) if atm_cfg in (False, "False", "false", "off", "none", "None"): print("[info] Atmodeller explicitly disabled in YAML; skipping atmodeller initialization.") return from .vert_chem import is_atmodeller_cache_loaded, load_atmodeller_cache if is_atmodeller_cache_loaded(): print("[info] Atmodeller cache already loaded") return species_list = list(getattr(atm_cfg, "species_network", None) or []) if not species_list: raise ValueError( "atmodeller.species_network must be a non-empty list in the YAML config. " "Example:\n atmodeller:\n species_network:\n - H2O_g\n - CO_g" ) # Parse optional solver tuning from atmodeller.solver (all fields optional) solver_cfg = getattr(atm_cfg, "solver", None) solver_kwargs = {} if solver_cfg is not None: for key in ("atol", "rtol", "max_steps", "multistart", "multistart_perturbation", "jac"): val = getattr(solver_cfg, key, None) if val is not None: solver_kwargs[key] = val nlay = int(cfg.physics.nlay) print(f"[info] Building atmodeller EquilibriumModel with {len(species_list)} species, nlay={nlay}") if solver_kwargs: print(f"[info] Atmodeller solver settings: {solver_kwargs}") load_atmodeller_cache(species_list, nlay, solver_kwargs=solver_kwargs) print("[info] Atmodeller cache loaded")
[docs] def init_quench_approx_if_needed(cfg: Any, exp_dir: Path) -> None: """Initialise the FastChem 5D grid and quench species list for quench_approx. Reads ``fastchem_grid_jax`` (grid path, solver, species map) and ``quench_approx.quench_species`` from the YAML config. Both caches are idempotent — already-loaded caches are skipped. Parameters ---------- cfg : config object Parsed YAML configuration object. exp_dir : `~pathlib.Path` Experiment directory for resolving relative paths. """ phys = getattr(cfg, "physics", None) if phys is None: return vert_chem_name = str(getattr(phys, "vert_chem", "") or "").lower() if vert_chem_name not in ("quench_approx", "quench"): return # Step 1: load the FastChem 5D grid (reuses init_fastchem_grid_if_needed, # which now recognises quench_approx/quench in its aliases). init_fastchem_grid_if_needed(cfg, exp_dir) # Step 2: load the quench species list. from .vert_chem import is_quench_approx_cache_loaded, load_quench_approx_cache if is_quench_approx_cache_loaded(): print("[info] Quench species already loaded") return qa_cfg = getattr(cfg, "quench_approx", None) if qa_cfg is None: raise ValueError( "quench_approx config block is required when physics.vert_chem is " "'quench_approx'. Add a 'quench_approx:' section with 'quench_species:' list." ) quench_species = list(getattr(qa_cfg, "quench_species", None) or []) if not quench_species: raise ValueError( "quench_approx.quench_species must be a non-empty list in the YAML config. " "Example:\n quench_approx:\n quench_species:\n - CO\n - CH4" ) param_names = _cfg_param_base_names(cfg) if "log_10_Kzz" not in param_names: raise ValueError( "quench_approx requires parameter 'log_10_Kzz' in cfg.params " "or an explicit transit_1_5d variant with _joint, _east, or _west suffix." ) if "log_10_g" not in param_names and "M_p" not in param_names: raise ValueError( "quench_approx requires either 'log_10_g' or 'M_p' (+ 'R_p') in cfg.params " "so that surface gravity is available for the mixing timescale. For transit_1_5d, " "use explicit _joint, _east, or _west variants." ) load_quench_approx_cache(quench_species) print(f"[info] Quench species loaded: {list(quench_species)}")