JAX Numerical Tools Reference
The hicosmo.utils.jax_tools module provides a complete set of JAX-optimized numerical tools, including integration, differentiation, root-finding, and ODE solvers.
All functions support JIT compilation and automatic differentiation.
Quick Start
from hicosmo.utils import (
    integrate_simpson,        # Simpson integration
    gauss_legendre_integrate, # Gauss-Legendre integration
    cumulative_trapezoid,     # Cumulative trapezoidal integration
    gradient_1d,              # 1D gradient
    newton_root,              # Newton's method root-finding
    odeint_rk4,               # RK4 ODE solver
)
Integration Tools
Trapezoidal Integration
- trapezoid(y, x, axis=-1)
JAX trapezoidal integration wrapper.
- Parameters:
y – Array of function values
x – Array of independent variable values
axis – Integration axis
- Returns:
Integration result
Example:
import jax.numpy as jnp
from hicosmo.utils import trapezoid

x = jnp.linspace(0, 1, 100)
y = x**2
result = trapezoid(y, x)  # ≈ 1/3
Simpson Integration
- simpson(y, x)
Composite Simpson integration for discretely sampled data.
- Parameters:
y – Array of function values (at least 3 points)
x – Array of independent variable values
- Returns:
Integration result
- Raises:
ValueError – If fewer than 3 points are provided
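The composite rule behind this function can be sketched in plain NumPy. This is an illustrative sketch on a uniform grid (the name simpson_uniform is mine, not hicosmo's, and hicosmo's simpson may also handle non-uniform spacing):

```python
import numpy as np

def simpson_uniform(y, x):
    """Composite Simpson's rule on a uniformly spaced grid.

    Illustrative sketch: requires an odd number of samples
    (an even number of intervals).
    """
    if len(x) < 3:
        raise ValueError("Simpson's rule needs at least 3 points")
    if len(x) % 2 == 0:
        raise ValueError("this sketch needs an odd number of samples")
    h = x[1] - x[0]
    # Weights 1, 4, 2, 4, ..., 2, 4, 1 over the samples
    return h / 3 * (y[0] + y[-1] + 4 * np.sum(y[1:-1:2]) + 2 * np.sum(y[2:-1:2]))

x = np.linspace(0, 1, 101)
print(simpson_uniform(x**2, x))  # ≈ 1/3 (exact for quadratics)
```

Simpson's rule is exact for polynomials up to degree 3, which is why the quadratic above integrates to machine precision.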
- integrate_simpson(func, a, b, *, num=512)
Integrate a function over the interval [a, b] using Simpson’s rule.
- Parameters:
func – Integrand f(x)
a – Lower integration limit
b – Upper integration limit
num – Number of sampling points (default 512, automatically adjusted to even)
- Returns:
Integration result
Example:
from hicosmo.utils import integrate_simpson
import jax.numpy as jnp

# Compute integral of x^2 from 0 to 1 = 1/3
result = integrate_simpson(lambda x: x**2, 0, 1)
print(f"Result: {result:.6f}")  # ≈ 0.333333
Log-Space Integration
- integrate_logspace(func, k_min, k_max, *, num=512)
Integrate on a logarithmic-space grid, suitable for power-law functions with positive domains.
- Parameters:
func – Integrand f(k)
k_min – Lower integration limit (must be > 0)
k_max – Upper integration limit
num – Number of sampling points
- Returns:
Integration result
Example:
from hicosmo.utils import integrate_logspace

# Integrate a power-law function
result = integrate_logspace(lambda k: k**(-2), 0.01, 100)
Gauss-Legendre Integration
High-precision quadrature method, highly effective for smooth functions.
- gauss_legendre_nodes_weights(order)
Return Gauss-Legendre nodes and weights for the specified order.
- Parameters:
order – Order (supports 8, 12, 16)
- Returns:
(nodes, weights) tuple
- gauss_legendre_integrate(integrand, a, b, *, order=12)
Perform Gauss-Legendre integration over the interval [a, b].
- Parameters:
integrand – Integrand function
a – Lower integration limit
b – Upper integration limit
order – Integration order (8, 12, or 16)
- Returns:
Integration result
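The underlying quadrature can be sketched with NumPy's standard Gauss-Legendre nodes: map the nodes from [-1, 1] onto [a, b] and sum the weighted function values. This is a sketch of the technique, not hicosmo's implementation:

```python
import numpy as np

def gl_integrate(f, a, b, order=12):
    """Gauss-Legendre quadrature on [a, b] (illustrative sketch)."""
    nodes, weights = np.polynomial.legendre.leggauss(order)
    # Affine map from the reference interval [-1, 1] to [a, b]
    x = 0.5 * (b - a) * nodes + 0.5 * (a + b)
    return 0.5 * (b - a) * np.sum(weights * f(x))

print(gl_integrate(np.sin, 0.0, np.pi))  # ≈ 2
```

An order-n rule is exact for polynomials up to degree 2n - 1, which is why low orders (8-16) already give high precision for smooth integrands.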
- gauss_legendre_integrate_batch(integrand, z_array, *, z_min=0.0, order=12)
Batch Gauss-Legendre integration, computing simultaneously for multiple upper limits.
- Parameters:
integrand – Integrand function
z_array – Array of upper limits
z_min – Lower integration limit
order – Integration order
- Returns:
Array of integration results
Example:
from hicosmo.utils import gauss_legendre_integrate_batch
import jax.numpy as jnp

# Compute integrals for multiple upper limits
z_values = jnp.array([0.5, 1.0, 2.0])
results = gauss_legendre_integrate_batch(
    lambda z: 1.0 / jnp.sqrt(0.3 * (1 + z)**3 + 0.7),
    z_values,
    z_min=0.0,
    order=12,
)
Cumulative Integration
- cumulative_trapezoid(y, x)
Cumulative trapezoidal integration, returning cumulative integral values starting from the first point.
- Parameters:
y – Array of function values
x – Array of independent variable values
- Returns:
Cumulative integral array (first element is 0)
Example:
from hicosmo.utils import cumulative_trapezoid
import jax.numpy as jnp

x = jnp.linspace(0, 1, 100)
y = jnp.ones_like(x)  # Constant function
cumulative = cumulative_trapezoid(y, x)
# cumulative[-1] ≈ 1.0
- integrate_batch_cumulative(func, z_array, *, z_min=0.0, z_max=None, n_grid=4096)
Batch integration using a cumulative grid and interpolation. Suitable for a large number of query points.
- Parameters:
func – Integrand function
z_array – Array of query points
z_min – Lower integration limit
z_max – Upper grid limit (defaults to the maximum of z_array)
n_grid – Number of grid points
- Returns:
Array of integration results
Performance tip: When there are many query points, this method is faster than
gauss_legendre_integrate_batch.
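The grid-plus-interpolation idea can be sketched in NumPy: build one dense cumulative-trapezoid table, then answer every query by interpolation, so the cost is one grid pass regardless of how many query points there are. The name batch_cumulative is mine; this is a sketch of the technique, not hicosmo's implementation:

```python
import numpy as np

def batch_cumulative(f, z_query, z_min=0.0, z_max=None, n_grid=4096):
    """I(z) = integral of f from z_min to z, for many z at once (sketch)."""
    if z_max is None:
        z_max = np.max(z_query)
    grid = np.linspace(z_min, z_max, n_grid)
    vals = f(grid)
    # Cumulative trapezoid: running sum of panel areas, first entry 0
    panels = 0.5 * (vals[1:] + vals[:-1]) * np.diff(grid)
    cum = np.concatenate([[0.0], np.cumsum(panels)])
    # One table lookup per query point
    return np.interp(z_query, grid, cum)

z = np.array([0.1, 0.5, 1.0])
print(batch_cumulative(lambda u: 2 * u, z))  # ≈ z**2
```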
Segmented Integration
- integrate_segmented(func, a, b, *, n_segments=8, order=12)
Segmented Gauss-Legendre integration, suitable for large integration intervals.
- Parameters:
func – Integrand function
a – Lower integration limit
b – Upper integration limit
n_segments – Number of segments
order – Gauss-Legendre order per segment
- Returns:
Integration result
Use case: When the integration interval is very large (e.g., [0, 1000]), a single Gauss-Legendre integration may lack precision; segmented integration improves accuracy.
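The segmentation idea can be sketched by reusing the standard Gauss-Legendre mapping on each sub-interval and summing the pieces. Function names here are mine; this is a sketch of the technique, not hicosmo's implementation:

```python
import numpy as np

def integrate_segmented_sketch(f, a, b, n_segments=8, order=12):
    """Equal-width segments, Gauss-Legendre on each (illustrative sketch)."""
    nodes, weights = np.polynomial.legendre.leggauss(order)
    edges = np.linspace(a, b, n_segments + 1)
    total = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        # Map reference nodes into this segment and accumulate
        x = 0.5 * (hi - lo) * nodes + 0.5 * (lo + hi)
        total += 0.5 * (hi - lo) * np.sum(weights * f(x))
    return total

# An oscillatory integrand over a wide interval benefits from segmentation
print(integrate_segmented_sketch(np.cos, 0.0, 30.0, n_segments=16))  # ≈ sin(30)
```

Keeping each segment short keeps the per-segment quadrature error small, so the sum stays accurate even when [a, b] spans many oscillations.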
Adaptive Simpson Integration
- integrate_adaptive_simpson(func, a, b, *, num=64, max_num=4096, tol=1e-6)
Adaptive Simpson integration, automatically adjusting grid density.
- Parameters:
func – Integrand function
a – Lower integration limit
b – Upper integration limit
num – Initial number of grid points (must be even)
max_num – Maximum number of grid points (must be an integer multiple of num)
tol – Convergence tolerance
- Returns:
Integration result
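The refinement loop can be sketched as grid doubling: compute a Simpson estimate, double the grid, and stop once two successive estimates agree to within tol (or max_num is reached). This sketches the general idea; hicosmo's integrate_adaptive_simpson may differ in detail:

```python
import numpy as np

def simpson_fixed(f, a, b, num):
    """Composite Simpson with `num` (even) panels on a uniform grid."""
    x = np.linspace(a, b, num + 1)
    y = f(x)
    h = (b - a) / num
    return h / 3 * (y[0] + y[-1] + 4 * np.sum(y[1:-1:2]) + 2 * np.sum(y[2:-1:2]))

def adaptive_simpson_sketch(f, a, b, num=64, max_num=4096, tol=1e-6):
    """Double the grid until successive estimates converge (sketch)."""
    prev = simpson_fixed(f, a, b, num)
    num *= 2
    while num <= max_num:
        cur = simpson_fixed(f, a, b, num)
        if abs(cur - prev) < tol:
            return cur
        prev, num = cur, num * 2
    return prev  # best available estimate if tolerance was never met

print(adaptive_simpson_sketch(np.exp, 0.0, 1.0))  # ≈ e - 1
```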
Differentiation Tools
1D Gradient
- gradient_1d(values, coords)
Compute the derivative on a 1D grid, supporting non-uniform grids.
- Parameters:
values – Array of function values
coords – Array of coordinates
- Returns:
Array of derivatives
Features:
- Endpoints use 3-point Lagrange interpolation (second-order accuracy)
- Interior points use second-order central differences
- Supports non-uniform grids
Example:
from hicosmo.utils import gradient_1d
import jax.numpy as jnp

x = jnp.linspace(0, 2 * jnp.pi, 100)
y = jnp.sin(x)
dy_dx = gradient_1d(y, x)  # ≈ cos(x)
Finite Difference Gradient
- finite_difference_grad(func, x, *, step=1e-5)
Compute the gradient of a scalar function using central finite differences.
- Parameters:
func – Scalar function f(x) -> scalar
x – Evaluation point (array)
step – Difference step size
- Returns:
Gradient array
Note: For differentiable functions, using jax.grad for automatic differentiation is recommended.
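The central-difference scheme can be sketched in NumPy: perturb one coordinate at a time by ±step and divide the difference by 2·step. The name fd_grad is mine; this illustrates the technique rather than hicosmo's implementation:

```python
import numpy as np

def fd_grad(f, x, step=1e-5):
    """Central-difference gradient of a scalar function (sketch).

    Prefer automatic differentiation (jax.grad) when f is differentiable.
    """
    x = np.asarray(x, dtype=float)
    g = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = step  # perturb only coordinate i
        g[i] = (f(x + e) - f(x - e)) / (2 * step)
    return g

print(fd_grad(lambda v: np.sum(v**2), np.array([1.0, 2.0])))  # ≈ [2., 4.]
```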
Automatic Differentiation Wrappers
- grad(func)
Return the gradient function of a function (jax.grad wrapper).
- jacobian(func)
Return the Jacobian matrix function of a function (jax.jacobian wrapper).
- hessian(func)
Return the Hessian matrix function of a function (jax.hessian wrapper).
Example:
from hicosmo.utils import grad, hessian
import jax.numpy as jnp

def f(x):
    return jnp.sum(x**2)

grad_f = grad(f)
hess_f = hessian(f)
x = jnp.array([1.0, 2.0])
print(grad_f(x))  # [2., 4.]
Root-Finding Tools
Newton-Raphson Method
- newton_root(func, x0, *, tol=1e-8, max_iter=64, deriv=None)
Solve f(x) = 0 using the Newton-Raphson method.
- Parameters:
func – Target function f(x)
x0 – Initial guess
tol – Convergence tolerance
max_iter – Maximum number of iterations
deriv – Derivative function (optional, defaults to automatic differentiation)
- Returns:
Approximate root
Example:
from hicosmo.utils import newton_root
import jax.numpy as jnp

# Solve x^2 - 2 = 0
root = newton_root(lambda x: x**2 - 2, jnp.array(1.5))
print(f"sqrt(2) ≈ {root:.10f}")  # 1.4142135624
Bisection Method
- bisection_root(func, a, b, *, tol=1e-8, max_iter=128)
Solve f(x) = 0 using the bisection method (requires opposite signs at interval endpoints).
- Parameters:
func – Target function f(x)
a – Left endpoint of the interval
b – Right endpoint of the interval
tol – Convergence tolerance
max_iter – Maximum number of iterations
- Returns:
Approximate root (returns NaN if endpoints have the same sign)
Example:
from hicosmo.utils import bisection_root
import jax.numpy as jnp

# Solve x^3 - x - 1 = 0
root = bisection_root(
    lambda x: x**3 - x - 1,
    jnp.array(1.0),
    jnp.array(2.0),
)
ODE Solvers
RK4 Single Step
- rk4_step(func, t, y, dt)
Perform one step of fourth-order Runge-Kutta integration.
- Parameters:
func – ODE right-hand side f(t, y)
t – Current time
y – Current state
dt – Time step
- Returns:
State at the next time step
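The classical RK4 step combines four slope evaluations with weights 1/6, 1/3, 1/3, 1/6. A minimal sketch of the standard scheme (not hicosmo's implementation; the name rk4_step_sketch is mine):

```python
def rk4_step_sketch(f, t, y, dt):
    """One classical fourth-order Runge-Kutta step (sketch)."""
    k1 = f(t, y)                          # slope at the start
    k2 = f(t + dt / 2, y + dt / 2 * k1)   # slope at the midpoint, using k1
    k3 = f(t + dt / 2, y + dt / 2 * k2)   # slope at the midpoint, using k2
    k4 = f(t + dt, y + dt * k3)           # slope at the end
    return y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

# One step of y' = y from y(0) = 1; compare with e^0.1 ≈ 1.1051709
print(rk4_step_sketch(lambda t, y: y, 0.0, 1.0, 0.1))
```

The local truncation error is O(dt⁵), giving the scheme its fourth-order global accuracy.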
RK4 Integrator
- odeint_rk4(func, y0, t_grid)
Solve ODE y' = f(t, y) on a fixed grid using the RK4 method.
- Parameters:
func – ODE right-hand side f(t, y)
y0 – Initial condition
t_grid – Time grid
- Returns:
Solution array (one row per time point)
Example:
from hicosmo.utils import odeint_rk4
import jax.numpy as jnp

# Solve y' = -y, y(0) = 1 (analytical solution: y = e^(-t))
def rhs(t, y):
    return -y

t = jnp.linspace(0, 5, 100)
y = odeint_rk4(rhs, jnp.array(1.0), t)
- solve_ivp_rk4(func, y0, t0, t1, *, n_steps=256)
Solve an initial value problem on the interval [t0, t1] using fixed-step RK4.
- Parameters:
func – ODE right-hand side f(t, y)
y0 – Initial condition
t0 – Start time
t1 – End time
n_steps – Number of steps
- Returns:
(t_grid, y_grid) tuple
Cosmology-Specific Tools
Sound Horizon Computation
Based on the Eisenstein & Hu (1998) fitting formula.
- sound_horizon_drag_eh98(H0, Omega_m, Omega_b, T_cmb)
Compute the sound horizon at the drag epoch r_d.
- Parameters:
H0 – Hubble constant [km/s/Mpc]
Omega_m – Matter density parameter
Omega_b – Baryon density parameter
T_cmb – CMB temperature [K]
- Returns:
Sound horizon r_d [Mpc]
Example:
from hicosmo.utils.jax_tools import sound_horizon_drag_eh98

r_d = sound_horizon_drag_eh98(
    H0=67.36,
    Omega_m=0.3153,
    Omega_b=0.0493,
    T_cmb=2.7255,
)
print(f"r_d = {r_d:.2f} Mpc")  # ≈ 147 Mpc
- sound_horizon_eh98(H0, Omega_m, Omega_b, T_cmb, z)
Compute the sound horizon at an arbitrary redshift z.
- Parameters:
H0 – Hubble constant [km/s/Mpc]
Omega_m – Matter density parameter
Omega_b – Baryon density parameter
T_cmb – CMB temperature [K]
z – Redshift
- Returns:
Sound horizon r_s(z) [Mpc]
Performance Comparison
Applicable scenarios for each integration method:
| Method | Accuracy | Use Case |
|---|---|---|
| trapezoid | O(h²) | Quick rough estimates |
| integrate_simpson | O(h⁴) | Standard choice for smooth functions |
| gauss_legendre_integrate | High | High-precision integration over small intervals |
| integrate_batch_cumulative | O(h²) | Batch integration with many query points |
| integrate_segmented | High | High-precision integration over large intervals |
JIT Compilation Notes
All functions are compatible with JAX JIT compilation:
from jax import jit
from hicosmo.utils import integrate_simpson
@jit
def my_function(params):
    result = integrate_simpson(
        lambda x: params[0] * x**2 + params[1],
        0, 1
    )
    return result
Notes:
- Integration limits [a, b] can be dynamic values
- Grid point counts (num, n_grid) must be static values
- Use static_argnums for parameters that need to be static
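Marking a grid-size argument as static can be sketched with plain jax.jit and static_argnums (this example is independent of hicosmo's helpers; the function name and the manual trapezoid sum are mine):

```python
from functools import partial
import jax
import jax.numpy as jnp

# `num` fixes the grid shape, so it must be known at trace time;
# static_argnums=(1,) tells jit to treat the second argument as static
# and recompile if it changes.
@partial(jax.jit, static_argnums=(1,))
def trapezoid_integral(scale, num):
    x = jnp.linspace(0.0, 1.0, num)
    y = scale * x**2
    # Manual trapezoid rule over the fixed grid
    return jnp.sum(0.5 * (y[1:] + y[:-1]) * jnp.diff(x))

print(trapezoid_integral(3.0, 101))  # ≈ 1.0
```

Passing `num` as a traced value instead would fail, because jnp.linspace needs a concrete length to build the array.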
Complete Example
Computing Comoving Distance
import jax.numpy as jnp
from hicosmo.utils import integrate_batch_cumulative
# Cosmological parameters
H0 = 70.0 # km/s/Mpc
Omega_m = 0.3
c = 299792.458 # km/s
def E_z(z):
    """Hubble parameter E(z) = H(z)/H0"""
    return jnp.sqrt(Omega_m * (1 + z)**3 + (1 - Omega_m))

# Compute comoving distance for multiple redshifts
z_array = jnp.array([0.1, 0.5, 1.0, 2.0])
integrals = integrate_batch_cumulative(
    lambda z: 1.0 / E_z(z),
    z_array,
    n_grid=4096
)
d_c = (c / H0) * integrals

for z, d in zip(z_array, d_c):
    print(f"z = {z:.1f}: d_c = {d:.1f} Mpc")
References
Eisenstein, D. J., & Hu, W. (1998). Baryonic Features in the Matter Transfer Function. ApJ, 496, 605.