'''
aux_functions.py
================
'''
from __future__ import annotations
import jax
import jax.numpy as jnp
from jax import lax
import functools
from typing import Optional, Literal
# Names exported by ``from <module> import *``; the underscore-prefixed
# helpers defined below are deliberately left out.
__all__ = ['pchip_1d', 'latin_hypercube', 'simpson', 'simpson_padded']
def _pchip_slopes(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
"""
Compute slopes for PCHIP interpolation.
Parameters
----------
x : `~jax.numpy.ndarray`
1D array of node positions with shape (N,), must be sorted in ascending
order. Minimum length is 2.
y : `~jax.numpy.ndarray`
1D array of function values at the node positions with shape (N,).
Must have the same length as `x`.
Returns
-------
m : `~jax.numpy.ndarray`
1D array of slopes (first derivatives) at each node position with shape (N,).
"""
x = jnp.asarray(x)
y = jnp.asarray(y)
N = x.shape[0]
h = jnp.diff(x)
delta = jnp.diff(y) / h
def slopes_N2():
m = jnp.full_like(y, delta[0])
return m
def slopes_Nge3():
h0 = h[:-1]
h1 = h[1:]
d0 = delta[:-1]
d1 = delta[1:]
w1 = 2.0 * h1 + h0
w2 = 2.0 * h0 + h1
same_sign = (d0 * d1) > 0.0
denom = (w1 / jnp.where(jnp.abs(d0) > 0.0, d0, 1.0)
+ w2 / jnp.where(jnp.abs(d1) > 0.0, d1, 1.0))
m_inner = (w1 + w2) / denom
m_inner = jnp.where(same_sign, m_inner, 0.0)
m0 = ((2.0*h[0] + h[1]) * delta[0] - h[0] * delta[1]) / (h[0] + h[1])
m0 = jnp.where(jnp.sign(m0) != jnp.sign(delta[0]), 0.0, m0)
m0 = jnp.where(
(jnp.sign(delta[0]) != jnp.sign(delta[1])) & (jnp.abs(m0) > jnp.abs(3.0 * delta[0])),
3.0 * delta[0],
m0,
)
mN = ((2.0*h[-1] + h[-2]) * delta[-1] - h[-1] * delta[-2]) / (h[-1] + h[-2])
mN = jnp.where(jnp.sign(mN) != jnp.sign(delta[-1]), 0.0, mN)
mN = jnp.where(
(jnp.sign(delta[-1]) != jnp.sign(delta[-2])) & (jnp.abs(mN) > jnp.abs(3.0 * delta[-1])),
3.0 * delta[-1],
mN,
)
m = jnp.empty_like(y)
m = m.at[0].set(m0)
m = m.at[1:-1].set(m_inner)
m = m.at[-1].set(mN)
return m
return jnp.where(N == 2, slopes_N2(), slopes_Nge3())
[docs]
def pchip_1d(x: jnp.ndarray,
             x_nodes: jnp.ndarray,
             y_nodes: jnp.ndarray) -> jnp.ndarray:
    """Piecewise Cubic Hermite Interpolating Polynomial (PCHIP).

    Evaluates a 1D shape-preserving cubic Hermite interpolant through the
    given nodes. Query points outside the node range are first clamped to
    the range, so they receive the value at the nearest boundary node.

    Parameters
    ----------
    x : `~jax.numpy.ndarray`
        Query coordinates. Any shape; evaluation is element-wise.
    x_nodes : `~jax.numpy.ndarray`
        1D array of node x-coordinates, sorted in ascending order.
        Minimum length is 2.
    y_nodes : `~jax.numpy.ndarray`
        1D array of node y-coordinates, same length as `x_nodes`.

    Returns
    -------
    y : `~jax.numpy.ndarray`
        Interpolated values with the same shape as `x`.
    """
    query = jnp.asarray(x)
    knots = jnp.asarray(x_nodes)
    values = jnp.asarray(y_nodes)

    # Node derivatives, shape (N,).
    slopes = _pchip_slopes(knots, values)

    # Clamp the queries into [first node, last node].
    xq = jnp.clip(query, knots[0], knots[-1])

    # Index of the left node of the segment containing each query; the upper
    # clamp keeps a query sitting exactly on the last node in the last segment.
    last_segment = knots.shape[0] - 2
    left = jnp.clip(
        jnp.searchsorted(knots, xq, side="right") - 1, 0, last_segment
    )
    right = left + 1

    x_left = knots[left]
    width = knots[right] - x_left

    # Normalised position inside the segment; the floor guards against a
    # zero-width segment.
    t = (xq - x_left) / jnp.maximum(width, 1e-30)
    t_sq = t**2
    t_cu = t**3

    # Cubic Hermite basis functions.
    h00 = 2.0 * t_cu - 3.0 * t_sq + 1.0
    h10 = t_cu - 2.0 * t_sq + t
    h01 = -2.0 * t_cu + 3.0 * t_sq
    h11 = t_cu - t_sq

    return (
        h00 * values[left]
        + h10 * width * slopes[left]
        + h01 * values[right]
        + h11 * width * slopes[right]
    )
[docs]
def latin_hypercube(
    key: jax.Array,
    n_samples: int,
    n_dim: int,
    *,
    scramble: bool = True,
    dtype=jnp.float64,
) -> tuple[jnp.ndarray, jax.Array]:
    """Generate Latin hypercube samples in the unit hypercube [0, 1)^n_dim.

    Each dimension is cut into `n_samples` equal strata and one point is
    drawn uniformly inside every stratum.

    Parameters
    ----------
    key : `~jax.Array`
        JAX PRNG key for random number generation.
    n_samples : int
        Number of samples to generate. Must be positive.
    n_dim : int
        Number of dimensions for each sample. Must be positive.
    scramble : bool, optional
        If True (default), the order of the strata is shuffled independently
        for every dimension. If False, sample ``i`` lies in stratum ``i`` of
        every dimension.
    dtype : dtype, optional
        Data type for the output array. Default is `jax.numpy.float64`.

    Returns
    -------
    samples : `~jax.numpy.ndarray`
        Latin hypercube samples with shape `(n_samples, n_dim)`.
    key : `~jax.Array`
        Fresh PRNG key for subsequent random operations.
    """
    float_type = jnp.dtype(dtype)
    next_key, jitter_key, shuffle_key = jax.random.split(key, 3)

    # Stratum index (one per row) plus a uniform offset inside the stratum.
    strata = jnp.arange(n_samples, dtype=float_type)[:, None]
    jitter = jax.random.uniform(jitter_key, (n_samples, n_dim), dtype=float_type)
    stratified = (strata + jitter) / jnp.asarray(n_samples, dtype=float_type)

    if scramble:
        # One independent row ordering per dimension, shape (n_dim, n_samples).
        per_dim_keys = jax.random.split(shuffle_key, n_dim)
        orderings = jax.vmap(
            lambda k: jax.random.permutation(k, n_samples)
        )(per_dim_keys)
        shuffled = jnp.take_along_axis(stratified, orderings.T, axis=0)
        return shuffled, next_key

    return stratified, next_key
# Type of the ``even`` argument of ``simpson``: how to treat an even number of
# samples. ``None`` is handled there as "simpson".
EvenMode = Optional[Literal["simpson", "avg", "first", "last"]]
def _as_last_axis(a: jnp.ndarray, axis: int) -> jnp.ndarray:
return jnp.moveaxis(a, axis, -1)
def _trapz_last(y: jnp.ndarray, x: Optional[jnp.ndarray], dx: float) -> jnp.ndarray:
# y shape (..., N>=2)
y0 = y[..., -2]
y1 = y[..., -1]
if x is None:
h = jnp.asarray(dx, dtype=y.dtype)
else:
h = x[..., -1] - x[..., -2]
return 0.5 * h * (y0 + y1)
def _trapz_first(y: jnp.ndarray, x: Optional[jnp.ndarray], dx: float) -> jnp.ndarray:
# y shape (..., N>=2)
y0 = y[..., 0]
y1 = y[..., 1]
if x is None:
h = jnp.asarray(dx, dtype=y.dtype)
else:
h = x[..., 1] - x[..., 0]
return 0.5 * h * (y0 + y1)
def _simpson_odd_uniform(y: jnp.ndarray, dx: float) -> jnp.ndarray:
# y shape (..., N) with N odd >= 3
n = y.shape[-1]
y0 = y[..., 0]
yN = y[..., -1]
odd_sum = jnp.sum(y[..., 1:n-1:2], axis=-1)
even_sum = jnp.sum(y[..., 2:n-1:2], axis=-1)
return (jnp.asarray(dx, dtype=y.dtype) / 3.0) * (y0 + yN + 4.0 * odd_sum + 2.0 * even_sum)
def _simpson_odd_unequal(y: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
# y, x shape (..., N) with N odd >= 3
# Composite Simpson for irregular spacing:
# sum over pairs of intervals with widths h0, h1:
# (h0+h1)/6 * [ (2 - h1/h0) y0 + ((h0+h1)^2/(h0*h1)) y1 + (2 - h0/h1) y2 ]
#
# IMPORTANT: When h0 or h1 is zero (repeated x-values from padding),
# that interval contributes nothing to the integral.
h = jnp.diff(x, axis=-1) # (..., N-1)
h0 = h[..., 0::2] # (..., (N-1)/2)
h1 = h[..., 1::2] # (..., (N-1)/2)
y0 = y[..., 0:-2:2]
y1 = y[..., 1:-1:2]
y2 = y[..., 2::2]
hsum = h0 + h1
# Use a small epsilon to detect zero spacing (from repeated x-values)
eps = 1e-15
h0_is_zero = jnp.abs(h0) < eps
h1_is_zero = jnp.abs(h1) < eps
# When either h0 or h1 is zero, the interval contributes 0
# Use jnp.where to avoid division by zero
safe_h0 = jnp.where(h0_is_zero, 1.0, h0) # Replace 0 with 1 to avoid division by zero
safe_h1 = jnp.where(h1_is_zero, 1.0, h1)
term0 = (2.0 - (h1 / safe_h0)) * y0
term1 = ((hsum * hsum) / (safe_h0 * safe_h1)) * y1
term2 = (2.0 - (h0 / safe_h1)) * y2
# Zero out contributions from intervals with zero spacing
interval_valid = ~(h0_is_zero | h1_is_zero)
contribution = (hsum / 6.0) * (term0 + term1 + term2)
contribution = jnp.where(interval_valid, contribution, 0.0)
return jnp.sum(contribution, axis=-1)
def _simpson_odd(y: jnp.ndarray, x: Optional[jnp.ndarray], dx: float) -> jnp.ndarray:
    """Composite Simpson rule for an odd number of samples on the last axis.

    Uses the irregular-spacing formula when sample points `x` are supplied
    and the uniform-spacing formula with step `dx` otherwise.
    """
    if x is not None:
        return _simpson_odd_unequal(y, x)
    return _simpson_odd_uniform(y, dx)
def _simpson_even_cartwright_last_interval(y: jnp.ndarray, x: Optional[jnp.ndarray], dx: float) -> jnp.ndarray:
# "simpson" behaviour for even N: Simpson on first N-1 points + special last-interval correction
# Uses last three points. For uniform spacing this reduces to:
# dx * (5/12*y[-1] + 2/3*y[-2] - 1/12*y[-3])
if x is None:
h = jnp.asarray(dx, dtype=y.dtype)
return h * ((5.0/12.0) * y[..., -1] + (2.0/3.0) * y[..., -2] - (1.0/12.0) * y[..., -3])
h0 = x[..., -2] - x[..., -3]
h1 = x[..., -1] - x[..., -2]
# Handle zero spacing (repeated x-values from padding)
eps = 1e-15
h0_is_zero = jnp.abs(h0) < eps
h1_is_zero = jnp.abs(h1) < eps
# If either spacing is zero, fall back to trapezoid rule on the valid portion
# or return zero if both are zero
safe_h0 = jnp.where(h0_is_zero, 1.0, h0)
safe_h1 = jnp.where(h1_is_zero, 1.0, h1)
alpha = (2.0 * safe_h1 * safe_h1 + 3.0 * safe_h0 * safe_h1) / (6.0 * (safe_h0 + safe_h1))
beta = (safe_h1 * safe_h1 + 3.0 * safe_h0 * safe_h1) / (6.0 * safe_h0)
eta = (safe_h1 * safe_h1 * safe_h1) / (6.0 * safe_h0 * (safe_h0 + safe_h1))
result = alpha * y[..., -1] + beta * y[..., -2] - eta * y[..., -3]
# If the last interval has zero spacing, the contribution should be zero
result = jnp.where(h1_is_zero, 0.0, result)
return result
[docs]
@functools.partial(jax.jit, static_argnames=("axis", "even"))
def simpson(
    y,
    *,
    x: Optional[jnp.ndarray] = None,
    dx: float = 1.0,
    axis: int = -1,
    even: EvenMode = None,
):
    """
    JAX-compatible composite Simpson integrator, similar to scipy.integrate.simpson.

    Parameters
    ----------
    y : array_like
        Values to integrate.
    x : array_like, optional
        Sample points. If 1D, must have length y.shape[axis]. Otherwise it
        must be broadcastable to the shape of `y`.
    dx : float
        Spacing used when x is None.
    axis : int
        Axis of integration.
    even : {None, 'simpson', 'avg', 'first', 'last'}
        Handling when the number of samples is even (and at least 4):

        - 'simpson' (also used for None): Simpson on the first N-1 points
          plus a three-point correction for the last interval.
        - 'first': Simpson on the first N-1 points + trapezoid on the last
          interval.
        - 'last': trapezoid on the first interval + Simpson on the last N-1
          points.
        - 'avg': average of 'first' and 'last'.

    Returns
    -------
    `~jax.numpy.ndarray`
        Integral estimate with the integration axis removed. Zero for fewer
        than two samples; the trapezoid rule for exactly two samples.

    Raises
    ------
    ValueError
        If the number of samples is even (>= 4) and `even` is not one of the
        supported values.
    """
    y = _as_last_axis(jnp.asarray(y), axis)
    n = y.shape[-1]

    # Bring x into the same "integration axis last" layout as y.
    if x is not None:
        x = jnp.asarray(x)
        if x.ndim != 1:
            x = _as_last_axis(x, axis)
        x = jnp.broadcast_to(x, y.shape)

    # Degenerate cases: nothing to integrate with fewer than two samples.
    if n < 2:
        return jnp.zeros(y.shape[:-1], dtype=y.dtype)
    if n == 2:
        # Simpson not possible; fall back to trapezoid.
        return _trapz_last(y, x, dx)

    # Odd number of samples: standard composite Simpson.
    if (n % 2) == 1:
        return _simpson_odd(y, x, dx)

    # Even number of samples: choose strategy. `even` is a static argument,
    # so this validation runs at trace time.
    mode = "simpson" if even is None else even
    if mode not in ("simpson", "avg", "first", "last"):
        raise ValueError(
            "Parameter 'even' must be 'simpson', 'avg', 'last', or 'first'."
        )

    x_head = None if x is None else x[..., :-1]
    x_tail = None if x is None else x[..., 1:]

    def first_estimate():
        # Simpson over first N-1 points + trapezoid on last interval.
        return _simpson_odd(y[..., :-1], x_head, dx) + _trapz_last(y, x, dx)

    def last_estimate():
        # Trapezoid on first interval + Simpson over last N-1 points.
        return _trapz_first(y, x, dx) + _simpson_odd(y[..., 1:], x_tail, dx)

    if mode == "first":
        return first_estimate()
    if mode == "last":
        return last_estimate()
    if mode == "avg":
        return 0.5 * (first_estimate() + last_estimate())

    # mode == "simpson": Simpson over first N-1 points + Cartwright correction.
    base = _simpson_odd(y[..., :-1], x_head, dx)
    corr = _simpson_even_cartwright_last_interval(y, x, dx)
    return base + corr
def _move_last(a: jnp.ndarray, axis: int) -> jnp.ndarray:
"""Move specified axis to last position."""
return jnp.moveaxis(a, axis, -1)
def _take_last_axis(a: jnp.ndarray, idx: jnp.ndarray) -> jnp.ndarray:
"""Take element from last axis at index idx. vmap-friendly."""
return jnp.take(a, idx, axis=-1, mode="clip")
[docs]
@functools.partial(jax.jit, static_argnames=("axis", "even"))
def simpson_padded(
    y: jnp.ndarray,
    x: jnp.ndarray,
    n_valid: jnp.ndarray,
    *,
    axis: int = -1,
    even: str = "first",
):
    """
    Composite Simpson integration for padded arrays with non-uniform spacing.

    Only the first `n_valid` samples along `axis` are integrated; whatever is
    stored in the padded tail is ignored. All shapes are static, so the
    function works under `jit` and `vmap`, and also directly on batched
    inputs.

    Parameters
    ----------
    y : `~jax.numpy.ndarray`
        Function values, padded to Nmax along `axis`.
    x : `~jax.numpy.ndarray`
        Sample points (same shape as y), padded to Nmax along `axis`.
        Must be strictly increasing on the valid prefix [0:n_valid].
    n_valid : int or `~jax.numpy.ndarray`
        Number of valid (unpadded) points along `axis`. Values are clipped to
        the range [2, Nmax]. Either a scalar or an integer array
        broadcastable to the batch shape (the shape of `y` without `axis`).
    axis : int, optional
        Axis of integration. Default: -1
    even : {'first', 'last', 'avg'}, optional
        Handling when n_valid is even:
        - 'first': Simpson on first n_valid-1 + trapezoid on last interval
        - 'last': Trapezoid on first interval + Simpson on last n_valid-1
        - 'avg': Average of 'first' and 'last'
        Any other value is treated as 'first'. Default: 'first'

    Returns
    -------
    integral : `~jax.numpy.ndarray`
        Integral estimate with the integration axis removed. Entries whose
        valid grid is not strictly increasing are NaN. With exactly two valid
        points the result is the trapezoid rule on that interval.
    """
    y = jnp.moveaxis(jnp.asarray(y), axis, -1)
    x = jnp.moveaxis(jnp.asarray(x), axis, -1)
    batch_shape = y.shape[:-1]
    n_max = y.shape[-1]

    if n_max < 2:
        # Fewer than two stored samples: there is no interval to integrate.
        return jnp.zeros(batch_shape, dtype=jnp.result_type(y, x))

    # Valid count per batch entry, with a trailing axis for broadcasting
    # against per-interval / per-panel quantities.
    n = jnp.clip(jnp.asarray(n_valid), 2, n_max)
    n = jnp.broadcast_to(n, batch_shape)[..., None]            # (..., 1)

    # ---- Grid sanity: require strictly increasing x on the valid prefix ----
    # Dynamic slicing is not possible under jit, so mask the diffs instead.
    h = jnp.diff(x, axis=-1)                                   # (..., n_max-1)
    interval = jnp.arange(n_max - 1)
    in_prefix = interval < (n - 1)
    grid_ok = jnp.all(jnp.where(in_prefix, h > 0.0, True), axis=-1)

    def simpson_panels(start, stop):
        # Sum of Simpson panels over the points start, start+1, ..., stop-1.
        # `start` is a static Python int, `stop` a (..., 1) integer array.
        n_pairs = (n_max - 1 - start) // 2
        first_pt = start + 2 * jnp.arange(n_pairs)
        h0 = h[..., start:start + 2 * n_pairs:2]
        h1 = h[..., start + 1:start + 2 * n_pairs:2]
        y0 = y[..., start:start + 2 * n_pairs:2]
        y1 = y[..., start + 1:start + 2 * n_pairs + 1:2]
        y2 = y[..., start + 2:start + 2 * n_pairs + 1:2]

        # A panel counts iff its third point lies inside the range and both
        # of its widths are positive.
        use = (first_pt + 2 < stop) & (h0 > 0.0) & (h1 > 0.0)

        # Unit denominators for unused panels keep every division finite
        # (values and gradients) before those panels are masked out.
        den0 = jnp.where(use, h0, 1.0)
        den1 = jnp.where(use, h1, 1.0)
        width = h0 + h1
        panel = (width / 6.0) * (
            (2.0 - h1 / den0) * y0
            + ((width * width) / (den0 * den1)) * y1
            + (2.0 - h0 / den1) * y2
        )
        return jnp.sum(jnp.where(use, panel, 0.0), axis=-1)

    def pick(a, idx):
        # Per-batch element of `a` at position `idx` on the last axis.
        return jnp.take_along_axis(a, idx, axis=-1)[..., 0]

    # Trapezoids on the first and on the last valid interval.
    trap_first = 0.5 * h[..., 0] * (y[..., 0] + y[..., 1])
    trap_last = 0.5 * (pick(x, n - 1) - pick(x, n - 2)) * (
        pick(y, n - 2) + pick(y, n - 1)
    )

    n_is_even = (n % 2) == 0
    # Simpson needs an odd number of points: all n when n is odd, else the
    # first n-1. For n == 2 this selects no panel at all.
    head_stop = jnp.where(n_is_even, n - 1, n)
    simpson_head = simpson_panels(0, head_stop)

    first_estimate = simpson_head + trap_last
    if even in ("last", "avg"):
        # 'last' integrates points 1..n-1 with Simpson, so its panels start
        # at index 1 rather than 0.
        last_estimate = trap_first + simpson_panels(1, n)
        if even == "last":
            even_estimate = last_estimate
        else:
            even_estimate = 0.5 * (first_estimate + last_estimate)
    else:
        # default "first": Simpson on first n-1 + last trapezoid
        even_estimate = first_estimate

    result = jnp.where(n_is_even[..., 0], even_estimate, simpson_head)

    # If the grid is bad (e.g. repeated x inside the valid region), return NaN.
    return jnp.where(grid_ok, result, jnp.nan)