JAX Numerical Tools Reference

The hicosmo.utils.jax_tools module provides JAX-based numerical tools for integration, differentiation, root-finding, and ODE solving, plus cosmology-specific helpers. All functions support JIT compilation and automatic differentiation.

Quick Start

from hicosmo.utils import (
    integrate_simpson,      # Simpson integration
    gauss_legendre_integrate,  # Gauss-Legendre integration
    cumulative_trapezoid,   # Cumulative trapezoidal integration
    gradient_1d,            # 1D gradient
    newton_root,            # Newton's method root-finding
    odeint_rk4,             # RK4 ODE solver
)

Integration Tools

Trapezoidal Integration

trapezoid(y, x, axis=-1)

JAX trapezoidal integration wrapper.

Parameters:
  • y – Array of function values

  • x – Array of independent variable values

  • axis – Integration axis

Returns:

Integration result

Example:

import jax.numpy as jnp
from hicosmo.utils import trapezoid

x = jnp.linspace(0, 1, 100)
y = x**2
result = trapezoid(y, x)  # ≈ 1/3
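
For reference, the trapezoidal rule sums 0.5 * (y[i] + y[i+1]) * (x[i+1] - x[i]) over adjacent samples. The following minimal sketch uses plain jax.numpy. `trapezoid_sketch` is a hypothetical name, not part of hicosmo, and it always integrates along the last axis instead of exposing `axis`.

```python
import jax.numpy as jnp

def trapezoid_sketch(y, x):
    """Trapezoidal rule along the last axis (hypothetical helper, not hicosmo's)."""
    dx = jnp.diff(x)                        # interval widths, shape (n-1,)
    avg = 0.5 * (y[..., 1:] + y[..., :-1])  # mean height on each interval
    return jnp.sum(avg * dx, axis=-1)

x = jnp.linspace(0.0, 1.0, 100)
approx = trapezoid_sketch(x**2, x)  # close to 1/3; slightly high for a convex integrand
```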

Simpson Integration

simpson(y, x)

Composite Simpson integration for discretely sampled data.

Parameters:
  • y – Array of function values (at least 3 points)

  • x – Array of independent variable values

Returns:

Integration result

Raises:

ValueError – If fewer than 3 points are provided
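
The composite Simpson rule weights the samples h/3 * (1, 4, 2, 4, ..., 2, 4, 1). The sketch below assumes a uniform 1D grid with an odd number of samples. It is a hypothetical stand-in: hicosmo's simpson may also accept non-uniform grids or even sample counts.

```python
import jax.numpy as jnp

def simpson_uniform_sketch(y, x):
    """Composite Simpson's rule, assuming uniform spacing and an odd sample count."""
    n = y.shape[-1]
    if n < 3 or n % 2 == 0:
        raise ValueError("need an odd number of samples, at least 3")
    h = (x[-1] - x[0]) / (n - 1)  # uniform spacing
    # Weights: 1 at the ends, 4 at odd interior indices, 2 at even interior indices
    return h / 3.0 * (y[0] + y[-1] + 4.0 * jnp.sum(y[1:-1:2]) + 2.0 * jnp.sum(y[2:-1:2]))

x = jnp.linspace(0.0, 1.0, 11)
cubic = simpson_uniform_sketch(x**3, x)  # ≈ 0.25: Simpson is exact for cubics up to rounding
```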

integrate_simpson(func, a, b, *, num=512)

Integrate a function over the interval [a, b] using Simpson’s rule.

Parameters:
  • func – Integrand f(x)

  • a – Lower integration limit

  • b – Upper integration limit

  • num – Number of sampling points (default 512; an odd value is automatically adjusted to an even one)

Returns:

Integration result

Example:

from hicosmo.utils import integrate_simpson
import jax.numpy as jnp

# Compute integral of x^2 from 0 to 1 = 1/3
result = integrate_simpson(lambda x: x**2, 0, 1)
print(f"Result: {result:.6f}")  # ≈ 0.333333
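
Because the quadrature is built from jax.numpy operations, derivatives flow through it. The sketch below re-implements a function-based Simpson rule and differentiates the integral with respect to a parameter of the integrand. `integrate_simpson_sketch` is hypothetical, and it treats `num` as the number of intervals, which may differ from hicosmo's counting convention.

```python
import jax
import jax.numpy as jnp

def integrate_simpson_sketch(func, a, b, num=512):
    """Simpson's rule on [a, b]; num is treated as an interval count, rounded up to even."""
    num = num + (num % 2)  # Simpson needs an even number of intervals
    x = jnp.linspace(a, b, num + 1)
    w = jnp.ones(num + 1).at[1:-1:2].set(4.0).at[2:-1:2].set(2.0)
    return (b - a) / (3.0 * num) * jnp.sum(w * func(x))

# d/dp of the integral of p * x**2 over [0, 1] equals 1/3 for any p
dI_dp = jax.grad(lambda p: integrate_simpson_sketch(lambda x: p * x**2, 0.0, 1.0))(2.0)
```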

Log-Space Integration

integrate_logspace(func, k_min, k_max, *, num=512)

Integrate on a grid that is uniform in ln k. This suits integrands that span many decades, such as power laws, on a strictly positive domain.

Parameters:
  • func – Integrand f(k)

  • k_min – Lower integration limit (must be > 0)

  • k_max – Upper integration limit

  • num – Number of sampling points

Returns:

Integration result

Example:

from hicosmo.utils import integrate_logspace

# Integrate a power law; the exact value is 1/0.01 - 1/100 = 99.99
result = integrate_logspace(lambda k: k**(-2), 0.01, 100)
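
The log-space idea is the substitution k = exp(u), dk = k du, which turns the integral into one over a uniform u grid. This sketch applies the trapezoidal rule on that grid; the choice of trapezoid (rather than Simpson) and the name `integrate_logspace_sketch` are assumptions, not hicosmo's implementation.

```python
import jax.numpy as jnp

def integrate_logspace_sketch(func, k_min, k_max, num=512):
    """Integrate func over [k_min, k_max] (k_min > 0) via k = exp(u)."""
    u = jnp.linspace(jnp.log(k_min), jnp.log(k_max), num)  # uniform grid in ln k
    k = jnp.exp(u)
    integrand = func(k) * k  # dk = k du
    du = u[1] - u[0]
    # Trapezoidal rule on the uniform u grid
    return du * (jnp.sum(integrand) - 0.5 * (integrand[0] + integrand[-1]))

power_law = integrate_logspace_sketch(lambda k: k**(-2), 0.01, 100.0)  # exact: 99.99
```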

Gauss-Legendre Integration

A high-precision quadrature rule, very effective for smooth integrands: an n-node Gauss-Legendre rule integrates polynomials of degree up to 2n - 1 exactly.

gauss_legendre_nodes_weights(order)

Return Gauss-Legendre nodes and weights for the specified order.

Parameters:

order – Order (supports 8, 12, 16)

Returns:

(nodes, weights) tuple

gauss_legendre_integrate(integrand, a, b, *, order=12)

Perform Gauss-Legendre integration over the interval [a, b].

Parameters:
  • integrand – Integrand function

  • a – Lower integration limit

  • b – Upper integration limit

  • order – Integration order (8, 12, or 16)

Returns:

Integration result
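
A minimal sketch of the rule: take the nodes and weights on [-1, 1] (here from NumPy's `numpy.polynomial.legendre.leggauss`, not hicosmo's tabulated values) and map them affinely onto [a, b]. It assumes that `order` means the number of nodes; `gauss_legendre_sketch` is a hypothetical name.

```python
import numpy as np
import jax.numpy as jnp

def gauss_legendre_sketch(integrand, a, b, order=12):
    """order-node Gauss-Legendre quadrature on [a, b] (order assumed = number of nodes)."""
    nodes, weights = np.polynomial.legendre.leggauss(order)  # rule on [-1, 1]
    nodes, weights = jnp.asarray(nodes), jnp.asarray(weights)
    half, mid = 0.5 * (b - a), 0.5 * (b + a)
    x = mid + half * nodes  # map [-1, 1] onto [a, b]
    return half * jnp.sum(weights * integrand(x))

exp_integral = gauss_legendre_sketch(jnp.exp, 0.0, 1.0)  # ≈ e - 1
```

Because the nodes are computed with NumPy, `order` must be a static Python integer under `jax.jit`.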

gauss_legendre_integrate_batch(integrand, z_array, *, z_min=0.0, order=12)

Batch Gauss-Legendre integration: computes the integral from z_min to each upper limit in z_array in a single vectorized call.

Parameters:
  • integrand – Integrand function

  • z_array – Array of upper limits

  • z_min – Lower integration limit

  • order – Integration order

Returns:

Array of integration results

Example:

from hicosmo.utils import gauss_legendre_integrate_batch
import jax.numpy as jnp

# Compute integrals for multiple upper limits
z_values = jnp.array([0.5, 1.0, 2.0])
results = gauss_legendre_integrate_batch(
    lambda z: 1.0 / jnp.sqrt(0.3 * (1 + z)**3 + 0.7),
    z_values,
    z_min=0.0,
    order=12
)
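
One way to batch over upper limits is to write the single-interval rule and map it with `jax.vmap`. The sketch below shows that pattern; it is an assumption about how the batching could be done, and `gl_upper` / `gl_batch_sketch` are hypothetical names.

```python
import numpy as np
import jax
import jax.numpy as jnp

nodes, weights = map(jnp.asarray, np.polynomial.legendre.leggauss(12))

def gl_upper(integrand, z_max, z_min=0.0):
    """12-node Gauss-Legendre integral of integrand over [z_min, z_max]."""
    half, mid = 0.5 * (z_max - z_min), 0.5 * (z_max + z_min)
    return half * jnp.sum(weights * integrand(mid + half * nodes))

def gl_batch_sketch(integrand, z_array, z_min=0.0):
    # vmap maps over the upper limits; integrand and z_min are shared
    return jax.vmap(lambda z: gl_upper(integrand, z, z_min))(z_array)

z_values = jnp.array([0.5, 1.0, 2.0])
batch = gl_batch_sketch(jnp.cos, z_values)  # ≈ sin(z_values)
```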

Cumulative Integration

cumulative_trapezoid(y, x)

Cumulative trapezoidal integration: element i of the result is the integral from x[0] to x[i].

Parameters:
  • y – Array of function values

  • x – Array of independent variable values

Returns:

Cumulative integral array (first element is 0)

Example:

from hicosmo.utils import cumulative_trapezoid
import jax.numpy as jnp

x = jnp.linspace(0, 1, 100)
y = jnp.ones_like(x)  # Constant function
cumulative = cumulative_trapezoid(y, x)
# cumulative[-1] ≈ 1.0

integrate_batch_cumulative(func, z_array, *, z_min=0.0, z_max=None, n_grid=4096)

Batch integration that tabulates the cumulative integral once on a grid of n_grid points and interpolates it at each query point. Suitable for a large number of query points.

Parameters:
  • func – Integrand function

  • z_array – Array of query points

  • z_min – Lower integration limit

  • z_max – Upper grid limit (defaults to the maximum of z_array)

  • n_grid – Number of grid points

Returns:

Array of integration results

Performance tip: When there are many query points, this method is faster than gauss_legendre_integrate_batch.
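
The two cumulative tools fit together as follows: compute the running trapezoidal integral on a fixed grid once, then interpolate it at any number of query points, so the integrand is evaluated n_grid times regardless of how many queries there are. This is a sketch under assumptions (linear interpolation via `jnp.interp`, hypothetical function names), not hicosmo's code.

```python
import jax.numpy as jnp

def cumulative_trapezoid_sketch(y, x):
    """Running trapezoidal integral from x[0]; the first element is 0."""
    areas = 0.5 * (y[1:] + y[:-1]) * jnp.diff(x)  # area of each interval
    return jnp.concatenate([jnp.zeros(1, dtype=areas.dtype), jnp.cumsum(areas)])

def integrate_batch_cumulative_sketch(func, z_array, z_min=0.0, z_max=None, n_grid=4096):
    """Tabulate the running integral once, then interpolate at every query point."""
    if z_max is None:
        z_max = jnp.max(z_array)  # the grid must cover all queries
    z_grid = jnp.linspace(z_min, z_max, n_grid)
    cum = cumulative_trapezoid_sketch(func(z_grid), z_grid)
    return jnp.interp(z_array, z_grid, cum)  # linear interpolation per query

x = jnp.linspace(0.0, 1.0, 100)
running = cumulative_trapezoid_sketch(2.0 * x, x)  # integral of 2x from 0 is x**2
queries = jnp.array([0.1, 0.5, 1.0, 2.0])
batch = integrate_batch_cumulative_sketch(jnp.cos, queries)  # ≈ sin(queries)
```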

Segmented Integration

integrate_segmented(func, a, b, *, n_segments=8, order=12)

Segmented Gauss-Legendre integration, suitable for large integration intervals.

Parameters:
  • func – Integrand function

  • a – Lower integration limit

  • b – Upper integration limit

  • n_segments – Number of segments

  • order – Gauss-Legendre order per segment

Returns:

Integration result

Use case: When the integration interval is very large (e.g., [0, 1000]), a single Gauss-Legendre integration may lack precision; segmented integration improves accuracy.
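
Segmenting means splitting [a, b] into equal pieces, applying the same Gauss-Legendre rule on each, and summing. The sketch below vectorizes over segments with `jax.vmap`; equal-width segments and NumPy-generated nodes are assumptions, and `integrate_segmented_sketch` is a hypothetical name.

```python
import numpy as np
import jax
import jax.numpy as jnp

def integrate_segmented_sketch(func, a, b, n_segments=8, order=12):
    """Split [a, b] into equal segments and apply Gauss-Legendre on each."""
    nodes, weights = map(jnp.asarray, np.polynomial.legendre.leggauss(order))
    edges = jnp.linspace(a, b, n_segments + 1)

    def one_segment(lo, hi):
        half, mid = 0.5 * (hi - lo), 0.5 * (hi + lo)
        return half * jnp.sum(weights * func(mid + half * nodes))

    return jnp.sum(jax.vmap(one_segment)(edges[:-1], edges[1:]))

# Integral of exp(-x / 50) over [0, 1000] is 50 * (1 - exp(-20)) ≈ 50
decay = integrate_segmented_sketch(lambda x: jnp.exp(-x / 50.0), 0.0, 1000.0)
```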

Adaptive Simpson Integration

integrate_adaptive_simpson(func, a, b, *, num=64, max_num=4096, tol=1e-6)

Adaptive Simpson integration: the grid is refined, starting from num points and up to at most max_num points, until the estimate converges to within tol.

Parameters:
  • func – Integrand function

  • a – Lower integration limit

  • b – Upper integration limit

  • num – Initial number of grid points (must be even)

  • max_num – Maximum number of grid points (must be an integer multiple of num)

  • tol – Convergence tolerance

Returns:

Integration result
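
A common refinement strategy is to double the grid and stop once two successive Simpson estimates agree within tol. The sketch below assumes that strategy; hicosmo may use a different criterion. It uses a Python-level loop, so unlike the library function it is not meant for `jax.jit`. `simpson_fixed` and `adaptive_simpson_sketch` are hypothetical names.

```python
import jax.numpy as jnp

def simpson_fixed(func, a, b, num):
    """Simpson's rule with num (even) intervals."""
    x = jnp.linspace(a, b, num + 1)
    w = jnp.ones(num + 1).at[1:-1:2].set(4.0).at[2:-1:2].set(2.0)
    return (b - a) / (3.0 * num) * jnp.sum(w * func(x))

def adaptive_simpson_sketch(func, a, b, num=64, max_num=4096, tol=1e-6):
    """Double the grid until successive estimates differ by less than tol."""
    estimate = simpson_fixed(func, a, b, num)
    while num * 2 <= max_num:
        num *= 2
        refined = simpson_fixed(func, a, b, num)
        if abs(float(refined - estimate)) < tol:  # Python-level check: not jit-compatible
            return refined
        estimate = refined
    return estimate  # max_num reached without meeting tol

adaptive = adaptive_simpson_sketch(jnp.exp, 0.0, 1.0)  # ≈ e - 1
```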

Differentiation Tools

1D Gradient

gradient_1d(values, coords)

Compute the derivative on a 1D grid, supporting non-uniform grids.

Parameters:
  • values – Array of function values

  • coords – Array of coordinates

Returns:

Array of derivatives

Features:
  • Endpoints use 3-point Lagrange interpolation (second-order accuracy)

  • Interior points use second-order central differences

  • Supports non-uniform grids

Example:

from hicosmo.utils import gradient_1d
import jax.numpy as jnp

x = jnp.linspace(0, 2 * jnp.pi, 100)
y = jnp.sin(x)
dy_dx = gradient_1d(y, x)  # ≈ cos(x)
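
The listed features can be written out explicitly. With left and right spacings h1 and h2, the interior formula is the derivative of the quadratic through three neighbouring points, and the endpoints use the same quadratic built from the first or last three points. This is a sketch consistent with the description above, not hicosmo's source; `gradient_1d_sketch` is hypothetical and needs at least 3 points.

```python
import jax.numpy as jnp

def gradient_1d_sketch(values, coords):
    """Second-order derivative estimate on a possibly non-uniform 1D grid."""
    f, x = values, coords
    h = jnp.diff(x)                 # h[i] = x[i+1] - x[i]
    h1, h2 = h[:-1], h[1:]          # spacing left / right of each interior point
    # Interior: non-uniform central difference, exact for quadratics
    interior = (h1**2 * f[2:] - h2**2 * f[:-2] + (h2**2 - h1**2) * f[1:-1]) / (
        h1 * h2 * (h1 + h2)
    )
    # Left endpoint: derivative of the quadratic through the first three points
    a, b = h[0], h[1]
    left = (-(2 * a + b) / (a * (a + b)) * f[0]
            + (a + b) / (a * b) * f[1]
            - a / (b * (a + b)) * f[2])
    # Right endpoint: same construction on the last three points
    c, d = h[-2], h[-1]
    right = (d / (c * (c + d)) * f[-3]
             - (c + d) / (c * d) * f[-2]
             + (2 * d + c) / (d * (c + d)) * f[-1])
    return jnp.concatenate([left[None], interior, right[None]])

x = 2 * jnp.pi * jnp.linspace(0.0, 1.0, 200) ** 1.5  # non-uniform grid
dy_dx = gradient_1d_sketch(jnp.sin(x), x)            # ≈ cos(x)
```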

Finite Difference Gradient

finite_difference_grad(func, x, *, step=1e-5)

Compute the gradient of a scalar function using central finite differences.

Parameters:
  • func – Scalar function f(x) -> scalar

  • x – Evaluation point (array)

  • step – Difference step size

Returns:

Gradient array

Note: For differentiable functions, using jax.grad for automatic differentiation is recommended.
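
Central differences estimate each partial derivative as (f(x + h e_i) - f(x - h e_i)) / (2h). The sketch below vectorizes over the coordinate directions with `jax.vmap` and compares against `jax.grad`. It is hypothetical (`finite_difference_grad_sketch` is not hicosmo's function) and assumes a 1D x. Under JAX's default float32, a step as small as 1e-5 loses digits to cancellation, so the usage passes a larger step.

```python
import jax
import jax.numpy as jnp

def finite_difference_grad_sketch(func, x, step=1e-5):
    """Central-difference gradient of a scalar function at a 1D point x."""
    shifts = step * jnp.eye(x.shape[0], dtype=x.dtype)  # row i is h * e_i

    def partial_i(shift):
        return (func(x + shift) - func(x - shift)) / (2.0 * step)

    return jax.vmap(partial_i)(shifts)

def f(v):
    return jnp.sum(jnp.sin(v))

point = jnp.array([0.1, 0.7, 1.3])
approx = finite_difference_grad_sketch(f, point, step=1e-3)  # larger step suits float32
exact = jax.grad(f)(point)                                   # cos(point)
```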

Automatic Differentiation Wrappers

grad(func)

Return the gradient function of a function (jax.grad wrapper).

jacobian(func)

Return the Jacobian matrix function of a function (jax.jacobian wrapper).

hessian(func)

Return the Hessian matrix function of a function (jax.hessian wrapper).

Example:

from hicosmo.utils import grad, hessian
import jax.numpy as jnp

def f(x):
    return jnp.sum(x**2)

grad_f = grad(f)
hess_f = hessian(f)

x = jnp.array([1.0, 2.0])
print(grad_f(x))  # [2. 4.]
print(hess_f(x))  # 2 * identity: [[2. 0.] [0. 2.]]
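
Since the three wrappers are described as thin layers over jax.grad, jax.jacobian and jax.hessian, they can be pictured as below. These definitions are a hypothetical sketch of that description, not hicosmo's source; the example also shows a Jacobian of a vector-valued map.

```python
import jax
import jax.numpy as jnp

# Hypothetical thin wrappers mirroring the documented behaviour
def grad(func):
    return jax.grad(func)

def jacobian(func):
    return jax.jacobian(func)

def hessian(func):
    return jax.hessian(func)

def g(v):
    # Vector-valued map R^2 -> R^2: (v0 * v1, sin v0)
    return jnp.stack([v[0] * v[1], jnp.sin(v[0])])

J = jacobian(g)(jnp.array([0.0, 3.0]))  # [[v1, v0], [cos v0, 0]] = [[3, 0], [1, 0]]
```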

Root-Finding Tools

Newton-Raphson Method

newton_root(func, x0, *, tol=1e-8, max_iter=64, deriv=None)

Solve f(x) = 0 using the Newton-Raphson method.

Parameters:
  • func – Target function f(x)

  • x0 – Initial guess

  • tol – Convergence tolerance

  • max_iter – Maximum number of iterations

  • deriv – Derivative function (optional, defaults to automatic differentiation)

Returns:

Approximate root

Example:

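import jax
jax.config.update("jax_enable_x64", True)  # float64: float32 cannot reach tol=1e-8 or 10 decimals
import jax.numpy as jnp
from hicosmo.utils import newton_root

# Solve x^2 - 2 = 0
root = newton_root(lambda x: x**2 - 2, jnp.array(1.5))
print(f"sqrt(2) ≈ {root:.10f}")  # 1.4142135624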
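
Newton-Raphson repeats x <- x - f(x) / f'(x) until the step falls below tol. In JAX the loop can be expressed with `lax.while_loop` so that it stays jit-compatible, and the derivative defaults to `jax.grad`. This sketch is hypothetical (`newton_root_sketch` is not hicosmo's code) and uses tol=1e-6 by default so that it also converges in float32.

```python
import jax
import jax.numpy as jnp
from jax import lax

def newton_root_sketch(func, x0, tol=1e-6, max_iter=64, deriv=None):
    """Newton-Raphson iteration inside lax.while_loop."""
    fprime = deriv if deriv is not None else jax.grad(func)  # autodiff by default
    x0 = jnp.asarray(x0, dtype=float)

    def cond(state):
        _, step, i = state
        return (jnp.abs(step) > tol) & (i < max_iter)

    def body(state):
        x, _, i = state
        step = func(x) / fprime(x)
        return x - step, step, i + 1

    init = (x0, jnp.asarray(jnp.inf, dtype=x0.dtype), jnp.asarray(0))
    x, _, _ = lax.while_loop(cond, body, init)
    return x

sqrt2 = newton_root_sketch(lambda x: x**2 - 2.0, 1.5)
```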

Bisection Method

bisection_root(func, a, b, *, tol=1e-8, max_iter=128)

Solve f(x) = 0 using the bisection method (requires opposite signs at interval endpoints).

Parameters:
  • func – Target function f(x)

  • a – Left endpoint of the interval

  • b – Right endpoint of the interval

  • tol – Convergence tolerance

  • max_iter – Maximum number of iterations

Returns:

Approximate root (returns NaN if endpoints have the same sign)

Example:

from hicosmo.utils import bisection_root
import jax.numpy as jnp

# Solve x^3 - x - 1 = 0 (real root ≈ 1.32472)
root = bisection_root(
    lambda x: x**3 - x - 1,
    jnp.array(1.0),
    jnp.array(2.0)
)
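
Bisection keeps a bracket [lo, hi] across which f changes sign and halves it each iteration, moving whichever endpoint shares the midpoint's sign. The sketch below runs a fixed number of halvings with `lax.fori_loop` and returns NaN for an invalid bracket, matching the documented behaviour. Dropping the `tol` early exit is a simplification, and `bisection_root_sketch` is a hypothetical name.

```python
import jax.numpy as jnp
from jax import lax

def bisection_root_sketch(func, a, b, max_iter=128):
    """Halve a sign-changing bracket max_iter times; NaN if f(a), f(b) share a sign."""
    a = jnp.asarray(a, dtype=float)
    b = jnp.asarray(b, dtype=float)
    fa = func(a)
    valid = jnp.sign(fa) * jnp.sign(func(b)) <= 0

    def body(_, state):
        lo, hi, flo = state
        mid = 0.5 * (lo + hi)
        fmid = func(mid)
        same = jnp.sign(fmid) == jnp.sign(flo)  # if so, the root lies in [mid, hi]
        lo = jnp.where(same, mid, lo)
        flo = jnp.where(same, fmid, flo)
        hi = jnp.where(same, hi, mid)
        return lo, hi, flo

    lo, hi, _ = lax.fori_loop(0, max_iter, body, (a, b, fa))
    return jnp.where(valid, 0.5 * (lo + hi), jnp.nan)

cubic_root = bisection_root_sketch(lambda x: x**3 - x - 1, 1.0, 2.0)
```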

ODE Solvers

RK4 Single Step

rk4_step(func, t, y, dt)

Perform one step of fourth-order Runge-Kutta integration.

Parameters:
  • func – ODE right-hand side f(t, y)

  • t – Current time

  • y – Current state

  • dt – Time step

Returns:

State at the next time step
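
Assuming rk4_step uses the classical Runge-Kutta coefficients, one step of size Δt = dt from (t, y) computes the following. The local error is O(Δt⁵) and the global error over a fixed interval is O(Δt⁴).

```latex
$$
\begin{aligned}
k_1 &= f(t,\, y), \\
k_2 &= f\!\left(t + \tfrac{\Delta t}{2},\; y + \tfrac{\Delta t}{2}\, k_1\right), \\
k_3 &= f\!\left(t + \tfrac{\Delta t}{2},\; y + \tfrac{\Delta t}{2}\, k_2\right), \\
k_4 &= f\!\left(t + \Delta t,\; y + \Delta t\, k_3\right), \\
y(t + \Delta t) &\approx y + \tfrac{\Delta t}{6}\left(k_1 + 2k_2 + 2k_3 + k_4\right).
\end{aligned}
$$
```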

RK4 Integrator

odeint_rk4(func, y0, t_grid)

Solve the ODE y' = f(t, y) on a fixed grid using the RK4 method.

Parameters:
  • func – ODE right-hand side f(t, y)

  • y0 – Initial condition

  • t_grid – Time grid

Returns:

Solution array (one row per time point)

Example:

from hicosmo.utils import odeint_rk4
import jax.numpy as jnp

# Solve y' = -y, y(0) = 1 (analytical solution: y = e^(-t))
def rhs(t, y):
    return -y

t = jnp.linspace(0, 5, 100)
y = odeint_rk4(rhs, jnp.array(1.0), t)  # y[-1] ≈ exp(-5) ≈ 0.00674

solve_ivp_rk4(func, y0, t0, t1, *, n_steps=256)

Solve an initial value problem on the interval [t0, t1] using fixed-step RK4.

Parameters:
  • func – ODE right-hand side f(t, y)

  • y0 – Initial condition

  • t0 – Start time

  • t1 – End time

  • n_steps – Number of steps

Returns:

(t_grid, y_grid) tuple
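
The fixed-grid integrators can be sketched by scanning an RK4 step over consecutive grid points with `lax.scan`, then stacking the initial state on top. For the interval form, the sketch assumes that n_steps steps means a grid of n_steps + 1 points. All names here are hypothetical, and the state is assumed to be real-valued.

```python
import jax.numpy as jnp
from jax import lax

def rk4_step_sketch(func, t, y, dt):
    """One classical RK4 step from (t, y) to t + dt."""
    k1 = func(t, y)
    k2 = func(t + 0.5 * dt, y + 0.5 * dt * k1)
    k3 = func(t + 0.5 * dt, y + 0.5 * dt * k2)
    k4 = func(t + dt, y + dt * k3)
    return y + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)

def odeint_rk4_sketch(func, y0, t_grid):
    """Row i of the result approximates y(t_grid[i])."""
    y0 = jnp.asarray(y0, dtype=t_grid.dtype)

    def step(y, t_pair):
        t_start, t_end = t_pair
        y_next = rk4_step_sketch(func, t_start, y, t_end - t_start)
        return y_next, y_next  # (new carry, value stacked into the output)

    _, ys = lax.scan(step, y0, (t_grid[:-1], t_grid[1:]))
    return jnp.concatenate([y0[None], ys], axis=0)

def solve_ivp_rk4_sketch(func, y0, t0, t1, n_steps=256):
    """Fixed-step RK4 on [t0, t1]; returns (t_grid, y_grid)."""
    t_grid = jnp.linspace(t0, t1, n_steps + 1)
    return t_grid, odeint_rk4_sketch(func, y0, t_grid)

t = jnp.linspace(0.0, 5.0, 100)
y = odeint_rk4_sketch(lambda t, y: -y, 1.0, t)  # ≈ exp(-t)
```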

Cosmology-Specific Tools

Sound Horizon Computation

Based on the Eisenstein & Hu (1998) fitting formula.

sound_horizon_drag_eh98(H0, Omega_m, Omega_b, T_cmb)

Compute the sound horizon at the drag epoch r_d.

Parameters:
  • H0 – Hubble constant [km/s/Mpc]

  • Omega_m – Matter density parameter

  • Omega_b – Baryon density parameter

  • T_cmb – CMB temperature [K]

Returns:

Sound horizon r_d [Mpc]

Example:

from hicosmo.utils.jax_tools import sound_horizon_drag_eh98

r_d = sound_horizon_drag_eh98(
    H0=67.36,
    Omega_m=0.3153,
    Omega_b=0.0493,
    T_cmb=2.7255
)
print(f"r_d = {r_d:.2f} Mpc")  # ≈ 147 Mpc

sound_horizon_eh98(H0, Omega_m, Omega_b, T_cmb, z)

Compute the sound horizon at an arbitrary redshift z.

Parameters:
  • H0 – Hubble constant [km/s/Mpc]

  • Omega_m – Matter density parameter

  • Omega_b – Baryon density parameter

  • T_cmb – CMB temperature [K]

  • z – Redshift

Returns:

Sound horizon r_s(z) [Mpc]
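
For reference, the published EH98 fitting formulae (their Eqs. 2 to 6) are given below, with ω_m = Ω_m h², ω_b = Ω_b h², h = H0 / (100 km/s/Mpc) and Θ_2.7 = T_cmb / 2.7 K. It is an assumption here, not stated elsewhere on this page, that sound_horizon_drag_eh98 returns r_s(z_d) from these expressions and that sound_horizon_eh98 evaluates the same closed form with R(z) in place of R(z_d).

```latex
$$
\begin{aligned}
z_{\rm eq} &= 2.50\times10^{4}\,\omega_m\,\Theta_{2.7}^{-4}, \qquad
k_{\rm eq} = 7.46\times10^{-2}\,\omega_m\,\Theta_{2.7}^{-2}\ \mathrm{Mpc}^{-1}, \\
z_d &= \frac{1291\,\omega_m^{0.251}}{1 + 0.659\,\omega_m^{0.828}}\left[1 + b_1\,\omega_b^{\,b_2}\right], \\
b_1 &= 0.313\,\omega_m^{-0.419}\left[1 + 0.607\,\omega_m^{0.674}\right], \qquad
b_2 = 0.238\,\omega_m^{0.223}, \\
R(z) &= 31.5\,\omega_b\,\Theta_{2.7}^{-4}\left(\frac{z}{10^{3}}\right)^{-1}, \qquad R_{\rm eq} = R(z_{\rm eq}), \\
r_s(z) &= \frac{2}{3k_{\rm eq}}\sqrt{\frac{6}{R_{\rm eq}}}\;
\ln\frac{\sqrt{1+R(z)} + \sqrt{R(z)+R_{\rm eq}}}{1+\sqrt{R_{\rm eq}}}, \qquad r_d = r_s(z_d).
\end{aligned}
$$
```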

Performance Comparison

Accuracy and typical use case of each integration method (h is the grid spacing):

Method                        Accuracy   Use case
trapezoid                     O(h²)      Quick rough estimates
simpson                       O(h⁴)      Standard choice for smooth functions
gauss_legendre_integrate      High       High-precision integration over small intervals
integrate_batch_cumulative    O(h²)      Batch integration with many query points
integrate_segmented           High       High-precision integration over large intervals
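
The O(h²) and O(h⁴) entries can be checked empirically: halving h should cut the trapezoid error by about 4 and the Simpson error by about 16. The sketch below uses self-contained helpers (`trap` and `simp` are hypothetical, not hicosmo functions) on the integral of exp over [0, 1], which equals e - 1, and enables float64 so that rounding does not mask the truncation error.

```python
import jax
jax.config.update("jax_enable_x64", True)  # float64, so rounding does not hide the error
import jax.numpy as jnp

def trap(func, a, b, n):
    """Trapezoidal rule with n intervals."""
    x = jnp.linspace(a, b, n + 1)
    y = func(x)
    return (b - a) / n * (jnp.sum(y) - 0.5 * (y[0] + y[-1]))

def simp(func, a, b, n):
    """Simpson's rule with n (even) intervals."""
    x = jnp.linspace(a, b, n + 1)
    w = jnp.ones(n + 1).at[1:-1:2].set(4.0).at[2:-1:2].set(2.0)
    return (b - a) / (3.0 * n) * jnp.sum(w * func(x))

exact = jnp.e - 1.0
for n in (8, 16, 32):
    err_t = abs(trap(jnp.exp, 0.0, 1.0, n) - exact)
    err_s = abs(simp(jnp.exp, 0.0, 1.0, n) - exact)
    print(n, float(err_t), float(err_s))
```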

JIT Compilation Notes

All functions are compatible with JAX JIT compilation:

from jax import jit
from hicosmo.utils import integrate_simpson

@jit
def my_function(params):
    result = integrate_simpson(
        lambda x: params[0] * x**2 + params[1],
        0, 1
    )
    return result

Notes:

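  1. Integration limits [a, b] may be dynamic (traced) values

  2. Grid point counts (num, n_grid) must be static values, because they determine array shapes

  3. If such a count is an argument of the function you compile, mark it static with static_argnums (or static_argnames)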
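
A minimal sketch of these rules, using a hypothetical `quadratic_integral` helper with its own Simpson rule rather than hicosmo's: the limits a and b are traced, while num is marked static through `static_argnames` because it fixes the grid's shape. Calling it with a different num triggers a recompilation.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="num")
def quadratic_integral(params, a, b, num=512):
    """Simpson integral of params[0] * x**2 + params[1] over [a, b]."""
    x = jnp.linspace(a, b, num + 1)  # a, b traced; num static (sets the shape)
    w = jnp.ones(num + 1).at[1:-1:2].set(4.0).at[2:-1:2].set(2.0)
    y = params[0] * x**2 + params[1]
    return (b - a) / (3.0 * num) * jnp.sum(w * y)

params = jnp.array([3.0, 1.0])
val_unit = quadratic_integral(params, 0.0, 1.0, num=64)  # 3/3 + 1 = 2
val_wide = quadratic_integral(params, 0.0, 2.0, num=64)  # 3 * 8/3 + 2 = 10, same compiled code
```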

Complete Example

Computing Comoving Distance

import jax.numpy as jnp
from hicosmo.utils import integrate_batch_cumulative

# Cosmological parameters
H0 = 70.0  # km/s/Mpc
Omega_m = 0.3
c = 299792.458  # km/s

def E_z(z):
    """Hubble parameter E(z) = H(z)/H0"""
    return jnp.sqrt(Omega_m * (1 + z)**3 + (1 - Omega_m))

# Compute comoving distance for multiple redshifts
z_array = jnp.array([0.1, 0.5, 1.0, 2.0])
integrals = integrate_batch_cumulative(
    lambda z: 1.0 / E_z(z),
    z_array,
    n_grid=4096
)
d_c = (c / H0) * integrals

for z, d in zip(z_array, d_c):
    print(f"z = {z:.1f}: d_c = {d:.1f} Mpc")

References

  • Eisenstein, D. J., & Hu, W. (1998). Baryonic Features in the Matter Transfer Function. ApJ, 496, 605.