# Source code for exo_skryer.help_runtime
"""
help_runtime.py
===============
"""
import os
__all__ = ['apply_runtime_env']
[docs]
def apply_runtime_env(rc):
    """
    Apply optional ``rc.runtime`` settings *before* importing JAX / NumPyro.

    All settings are communicated through environment variables, so this
    must run before ``jax`` (or anything that imports it) is imported.

    Expected (all optional):

    runtime:
      platform: cpu | gpu | cuda | metal | rocm
      cuda_visible_devices: "0" or "0,1"   (a list such as [0, 1] also works)
      preallocate: true/false
      mem_fraction: 0.8
      jax_threads: 4
      blackjax_threads: 4
      numpyro_threads: 4
      numpyro_platform: cpu | gpu   # optional override for NumPyro/JAX device
      jax_x64: true/false

    Boolean options (``preallocate``, ``jax_x64``) accept real booleans or the
    strings "true"/"false", "yes"/"no", "on"/"off", "1"/"0" (any case).

    Parameters
    ----------
    rc : object
        Configuration object. Only its ``runtime`` attribute is read; any
        missing attribute on ``runtime`` is treated as "not set".

    Returns
    -------
    numpyro_threads : int | None
        Hint for numpyro.set_host_device_count in driver implementation, or
        None when ``rc.runtime`` or ``runtime.numpyro_threads`` is absent.

    Raises
    ------
    ValueError
        If a boolean option is a string that cannot be interpreted, or a
        thread count string is not an integer.
    """
    rt = getattr(rc, "runtime", None)
    if rt is None:
        return None

    # --------------------
    # JAX platform / device
    # --------------------
    plat = getattr(rt, "platform", None)
    if plat:
        plat_norm = str(plat).lower()
        if plat_norm == "gpu":
            # Convenience alias. Rough heuristic: macOS -> metal, else cuda.
            # sys.platform exists on every OS (os.uname() is missing on Windows).
            import sys

            plat_norm = "metal" if sys.platform.startswith("darwin") else "cuda"
        # JAX >= 0.4 prefers JAX_PLATFORMS; we also set JAX_PLATFORM_NAME for older code.
        os.environ["JAX_PLATFORMS"] = plat_norm
        if plat_norm == "cpu":
            os.environ["JAX_PLATFORM_NAME"] = "cpu"
        elif plat_norm in ("cuda", "metal", "rocm"):
            # JAX treats these as "gpu" at the higher level
            os.environ["JAX_PLATFORM_NAME"] = "gpu"

    # Optional NumPyro-specific override (still via JAX). Applied after the
    # platform block so it wins over the JAX_PLATFORM_NAME set there.
    np_plat = getattr(rt, "numpyro_platform", None)
    if np_plat:
        np_plat_norm = str(np_plat).lower()
        if np_plat_norm in ("cpu", "gpu"):
            os.environ["JAX_PLATFORM_NAME"] = np_plat_norm

    # --------------------
    # CUDA device visibility
    # --------------------
    cvis = getattr(rt, "cuda_visible_devices", None)
    if cvis is not None:
        if isinstance(cvis, (list, tuple)):
            # e.g. YAML `[0, 1]`: str() would give "[0, 1]", which is not a
            # valid comma-separated device list.
            cvis = ",".join(str(dev) for dev in cvis)
        os.environ["CUDA_VISIBLE_DEVICES"] = str(cvis)

    # --------------------
    # XLA memory behaviour
    # --------------------
    prealloc = getattr(rt, "preallocate", None)
    if prealloc is not None:
        os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = (
            "true" if _as_bool(prealloc, "preallocate") else "false"
        )
    memfrac = getattr(rt, "mem_fraction", None)
    if memfrac is not None:
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(memfrac)

    # --------------------
    # Numerical precision
    # --------------------
    # Documented option that was previously ignored; JAX reads JAX_ENABLE_X64
    # at import time.
    x64 = getattr(rt, "jax_x64", None)
    if x64 is not None:
        os.environ["JAX_ENABLE_X64"] = "true" if _as_bool(x64, "jax_x64") else "false"

    # --------------------
    # Thread / host-device counts
    # --------------------
    jax_threads = getattr(rt, "jax_threads", None)
    if jax_threads is not None:
        # NOTE(review): JAX_NUM_THREADS does not appear to be a documented
        # JAX/XLA setting -- confirm something downstream reads it.
        os.environ["JAX_NUM_THREADS"] = str(int(jax_threads))

    # For BlackJAX on CPU we sometimes want XLA to expose N "host devices".
    # Merge into any existing XLA_FLAGS instead of overwriting user flags.
    blackjax_threads = getattr(rt, "blackjax_threads", None)
    if blackjax_threads is not None:
        os.environ["XLA_FLAGS"] = _set_xla_flag(
            os.environ.get("XLA_FLAGS", ""),
            "--xla_force_host_platform_device_count",
            int(blackjax_threads),
        )

    # --------------------
    # NumPyro hint: how many host devices / chains to use
    # --------------------
    numpyro_threads = getattr(rt, "numpyro_threads", None)
    return int(numpyro_threads) if numpyro_threads is not None else None


def _as_bool(value, name):
    """Interpret a config value as a boolean, accepting common string spellings.

    Non-string values use Python truthiness. Strings are matched
    case-insensitively; an empty string counts as False, mirroring truthiness.
    Raises ValueError for any other string, so that a typo does not silently
    enable an option.
    """
    if isinstance(value, str):
        text = value.strip().lower()
        if text in ("true", "yes", "on", "1", "y", "t"):
            return True
        if text in ("false", "no", "off", "0", "n", "f", ""):
            return False
        raise ValueError(
            f"runtime.{name}: cannot interpret {value!r} as a boolean"
        )
    return bool(value)


def _set_xla_flag(existing, flag, value):
    """Return the *existing* XLA_FLAGS string with *flag* set to *value*.

    Any earlier occurrence of *flag* (bare or ``flag=...``) is dropped so the
    new value is the only one. All other flags are kept in their original order.
    """
    kept = [
        tok
        for tok in existing.split()
        if tok != flag and not tok.startswith(flag + "=")
    ]
    kept.append(f"{flag}={value}")
    return " ".join(kept)