JAX Technical Deep Dive

This chapter provides a detailed introduction to JAX’s core concepts and technical advantages, as well as why HIcosmo chose JAX as its computation engine. If you are unfamiliar with JAX, this chapter will help you understand why JAX is a revolutionary tool for modern scientific computing.

What is JAX?

JAX is a high-performance numerical computing library developed by Google. It can be thought of as “differentiable, compilable, parallelizable NumPy”. Its name derives from Just-in-time compilation + Automatic differentiation + XLA (Accelerated Linear Algebra).

Official definition:

JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.

In plain terms:

  • If you know NumPy, you know JAX (the APIs are nearly identical)

  • JAX can automatically compute derivatives of any function (including higher-order derivatives)

  • JAX can automatically compile code into efficient machine code

  • JAX can transparently run on CPUs, GPUs, and TPUs

JAX is regarded as a differentiable programming language that uses Python as a “metaprogramming language” for building XLA computation graphs.

Functional Programming: The Core Philosophy of JAX

The design core of JAX is Functional Programming (FP), which fundamentally differs from traditional object-oriented programming (OOP) frameworks such as PyTorch.

Why Functional Programming?

The main reasons JAX adopts functional programming include:

  1. Mathematical consistency: Automatic differentiation is inherently a transformation of functions. Functional programming allows JAX to freely nest and compose grad (differentiation), jit (compilation), and vmap (vectorization) as higher-order functions, e.g.: jit(vmap(grad(f)))

  2. Easier compiler optimization: The XLA compiler requires a static, side-effect-free computation graph to perform operator fusion and memory optimization. State mutation in OOP (such as modifying class member variables) breaks the construction of such static graphs

  3. Deterministic random state: Traditional OOP frameworks often use global random state, while JAX mandates explicit passing of pseudo-random number generator (PRNG) keys, ensuring code reproducibility

  4. Separation of state and logic: Under the functional paradigm, model parameters are passed as function inputs rather than stored inside objects, making parameter distribution and updates more transparent
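These design choices pay off in practice: because the transformations are ordinary higher-order functions, a compiled, batched, per-example gradient is a single composition. A minimal sketch:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

# Compose the transformations: per-example gradients of f,
# vectorized over a batch with vmap, then compiled with jit.
per_example_grad = jax.jit(jax.vmap(jax.grad(f)))

batch = jnp.arange(6.0).reshape(3, 2)   # three inputs of dimension 2
print(per_example_grad(batch))          # gradient of sum(x^2) is 2x per row
# [[ 0.  2.]
#  [ 4.  6.]
#  [ 8. 10.]]
```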

Pure Functions

Pure functions are the cornerstone of JAX’s ability to achieve high-performance computation.

Definition: A pure function has no side effects: it neither reads nor modifies state outside itself.

Property: For the same input, a pure function always produces the same output.

Why it matters: The predictability of pure functions allows JAX’s XLA compiler to deeply optimize code, enable just-in-time compilation (JIT), and easily achieve operator parallelism and sharding on GPUs/TPUs.

import jax.numpy as jnp

# Pure function: same input -> same output, no side effects
def pure_function(x, y):
    return jnp.sin(x) + jnp.cos(y)

# Impure function: depends on external state
global_state = 0
def impure_function(x):
    global global_state
    global_state += 1  # Side effect: modifies external state
    return x + global_state

Immutable Arrays

In JAX, arrays cannot be modified once created.

Difference from NumPy: Traditional NumPy allows in-place array updates; JAX arrays are immutable.

How it works: If you need to modify an element in an array, JAX does not directly overwrite the original memory. Instead, it creates a new array containing the modified content.

import numpy as np
import jax.numpy as jnp

# NumPy: in-place modification
np_arr = np.array([1, 2, 3])
np_arr[0] = 99  # Direct modification

# JAX: immutable, requires creating a new array
jax_arr = jnp.array([1, 2, 3])
# jax_arr[0] = 99  # TypeError!
jax_arr = jax_arr.at[0].set(99)  # Creates a new array

Design rationale: Immutability unlocks compiler optimizations, allowing XLA to safely perform operator fusion and memory reuse.

Gradient Computation Comparison: PyTorch vs JAX

Below are the core differences between PyTorch (object-oriented) and JAX (functional) in handling gradients:

PyTorch (Object-Oriented Paradigm):

Gradients are treated as an attribute of tensor objects, implemented through state mutation.

import torch

def f(x):
    return x**2 + 3*x + 5

x = torch.tensor([2.0], requires_grad=True)
loss = f(x)
loss.backward()  # Triggers AutoGrad and traces back through the computation graph

# Gradients are "stored" in the .grad attribute of variable x (state change)
print(x.grad.item())  # Output: 7.0

JAX (Functional Paradigm):

Differentiation is a function transformation that does not alter the state of any input or object.

import jax

def f(x):
    return x**2 + 3*x + 5

# jax.grad takes a function and returns another function (function transformation)
df_dx = jax.grad(f)

x_value = 2.0
# The gradient is returned directly as a return value, without modifying the original variable
derivative_value = df_dx(x_value)

print(derivative_value)  # Output: 7.0

Comparison summary:

| Feature | Object-Oriented (PyTorch) | Functional (JAX) |
| --- | --- | --- |
| State management | State stored inside objects; in-place modification allowed | Pure functions; state passed via parameters; immutable |
| Gradient handling | Stored in the .grad attribute via loss.backward() | jax.grad transforms the function and returns gradient values directly |
| Random numbers | Implicit global random state | Explicit PRNG key passing |
| Underlying core | Dynamic computation graph, easy to debug | XLA-compiled static computation graph, maximizing performance |

The Four Core Features of JAX

1. NumPy-Compatible API

JAX provides an API that is almost fully compatible with NumPy:

# NumPy code
import numpy as np

x = np.array([1.0, 2.0, 3.0])
y = np.sin(x) + np.cos(x)
z = np.dot(x, y)

# JAX code (just change the import)
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.sin(x) + jnp.cos(x)
z = jnp.dot(x, y)

Key differences:

| Feature | NumPy | JAX |
| --- | --- | --- |
| Mutability | Arrays are mutable (x[0] = 5) | Arrays are immutable (x = x.at[0].set(5)) |
| Random numbers | Global state (np.random.rand()) | Explicit key (jax.random.uniform(key)) |
| Execution mode | Eager execution | Can be deferred (traced) |
| Hardware support | CPU only | CPU / GPU / TPU |

2. Automatic Differentiation

Automatic differentiation is one of JAX’s most powerful features. It can automatically compute exact derivatives of arbitrarily complex functions.

What is automatic differentiation?

In scientific computing, we frequently need to compute derivatives (gradients) of functions. There are three traditional methods:

  1. Symbolic differentiation: Deriving closed-form expressions by hand or with a computer algebra system; tedious, error-prone, and prone to expression swell for complex functions

  2. Numerical differentiation: Finite difference approximation, \(f'(x) \approx \frac{f(x+h) - f(x)}{h}\); limited precision, and expensive when gradients over many parameters are needed

  3. Automatic differentiation: Automatically traces the computation graph via the chain rule, exact and efficient

JAX automatic differentiation example:

import jax
import jax.numpy as jnp

# Define a complex function
def f(x):
    return jnp.sin(x) * jnp.exp(-x**2) + jnp.log(1 + x**2)

# Automatically obtain derivative functions
df = jax.grad(f)      # First derivative
d2f = jax.grad(df)    # Second derivative
d3f = jax.grad(d2f)   # Third derivative

# Compute derivative values
x = 1.0
print(f"f(x)   = {f(x):.6f}")
print(f"f'(x)  = {df(x):.6f}")
print(f"f''(x) = {d2f(x):.6f}")
print(f"f'''(x)= {d3f(x):.6f}")

Jacobian and Hessian of vector functions:

# Hessian matrix (used in optimization and Fisher matrix)
def scalar_func(x):
    return jnp.sum(x**2)

hessian = jax.hessian(scalar_func)
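For completeness alongside the Hessian, the Jacobian of a vector-valued function is available via jax.jacfwd (forward mode) or jax.jacrev (reverse mode); a small illustrative example:

```python
import jax
import jax.numpy as jnp

# Jacobian of a vector-valued function of a vector input
def vector_func(x):
    return jnp.array([x[0] * x[1], jnp.sin(x[0])])

# Rows are gradients of each output component
J = jax.jacfwd(vector_func)(jnp.array([1.0, 2.0]))
print(J)  # [[x1, x0], [cos(x0), 0]] = [[2.0, 1.0], [0.5403, 0.0]]
```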

Advantages of automatic differentiation:

  • Exact: Results are exact mathematical derivatives, not numerical approximations

  • Efficient: A reverse-mode gradient costs only a small constant multiple of one evaluation of the original function, regardless of the number of parameters

  • General: Supports arbitrarily complex function compositions, including conditionals and loops
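The last point can be seen directly: jax.grad traces with concrete values, so an ordinary Python conditional is differentiated along whichever branch is taken. A small sketch:

```python
import jax
import jax.numpy as jnp

# A piecewise function using plain Python control flow
def f(x):
    if x > 0:
        return jnp.sin(x)
    return x ** 2

df = jax.grad(f)
print(df(1.0))   # branch 1: cos(1.0) ~ 0.5403
print(df(-2.0))  # branch 2: 2 * (-2) = -4.0
```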

3. JIT Compilation and XLA Compiler

JIT compilation is the core source of JAX’s performance. It compiles Python code into optimized machine code.

How the XLA Compiler Works

XLA (Accelerated Linear Algebra) is a compiler specifically designed for linear algebra operations. It transforms programs written in JAX into optimized executable kernels for specific hardware (CPU, GPU, or TPU).

The workflow consists of the following stages:

  1. Tracing: When you apply jax.jit to a Python function and call it for the first time, JAX uses a “tracing” mechanism to record the sequence of operations performed by the function on arrays with specific shapes and types

  2. Building the computation graph: The traced information is transformed into an XLA computation graph (HLO, High-Level Optimizer representation)

  3. Hardware compilation and optimization: The XLA compiler performs comprehensive analysis on the graph, executing optimization strategies including dead code elimination, operator fusion, and more, finally compiling it into binary machine code that runs directly on hardware accelerators
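The tracing stage can be inspected directly: jax.make_jaxpr returns the intermediate representation (a "jaxpr") that JAX records before handing the program to XLA. A small sketch:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0 + 1.0

# The jaxpr is the traced sequence of primitive operations
# that XLA subsequently optimizes and compiles.
print(jax.make_jaxpr(f)(jnp.ones(3)))
```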

Operator Fusion

Operator fusion is one of XLA’s most powerful optimization techniques.

Principle: In traditional vectorized programming, each simple operation (such as A * B + C) typically generates many intermediate temporary arrays. These intermediate results need to be written back to memory and then read again, causing significant memory bandwidth waste.

Implementation: XLA identifies consecutive operations in the computation graph and “fuses” them into a single GPU kernel. This means intermediate computation results are kept directly in the processor’s fast registers or cache, rather than frequently exchanging data with main memory.

Benefit: Significantly reduces memory bandwidth usage and memory allocation pressure, particularly avoiding OOM (out of memory) errors when processing high-dimensional cosmological data.

Analogy:

Think of traditional Python operations like ordering at a fast-food counter: every time you order a burger, the server runs back to the kitchen (calls a C library), makes it, brings it to you, then you order fries, and the server runs back again.

The XLA compiler and JIT are like a smart head chef: you tell them you want a combo meal (the entire function), they analyze the menu, decide to fry the fries and grill the patty simultaneously (operator fusion), and finally deliver the entire hot combo meal to you at once.

This “one-stop” processing greatly reduces the waiting time in transit (Python overhead and memory read/write).

JIT compilation example:

import jax
import jax.numpy as jnp
import time

# Define a compute-intensive function
def slow_function(x):
    for _ in range(100):
        x = jnp.sin(x) + jnp.cos(x)
    return x

# JIT-compiled version
fast_function = jax.jit(slow_function)

x = jnp.ones(10000)

# First call: includes compilation time
start = time.time()
fast_function(x).block_until_ready()  # block: JAX dispatches asynchronously
print(f"First call (with compilation): {time.time() - start:.4f}s")

# Subsequent calls: uses cached compiled result
start = time.time()
for _ in range(100):
    fast_function(x).block_until_ready()
print(f"Subsequent 100 calls: {time.time() - start:.4f}s")

4. Vectorization with vmap

vmap is JAX’s vectorization transformation that automatically converts scalar functions into batched functions.

Example:

import jax
import jax.numpy as jnp

# Define a function that processes a single sample
def single_sample_loglike(theta, data_point):
    mu, sigma = theta
    return -0.5 * ((data_point - mu) / sigma)**2 - jnp.log(sigma)

# Automatically batch using vmap
batch_loglike = jax.vmap(single_sample_loglike, in_axes=(None, 0))

theta = (0.0, 1.0)
data = jnp.array([0.1, 0.5, -0.3, 0.8, -0.2])

# Compute log-likelihood for all data points at once
log_likes = batch_loglike(theta, data)

Performance comparison (1000 data points):

  • Python loop: ~50ms

  • NumPy vectorized: ~2ms

  • JAX vmap + JIT: ~0.1ms
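The same pattern also batches over parameter vectors rather than data points, e.g. evaluating one likelihood per MCMC chain. A hedged sketch with illustrative parameter values:

```python
import jax
import jax.numpy as jnp

# Log-likelihood of the whole dataset for one parameter vector
def loglike(theta, data):
    mu, sigma = theta
    return jnp.sum(-0.5 * ((data - mu) / sigma) ** 2 - jnp.log(sigma))

data = jnp.array([0.1, 0.5, -0.3, 0.8, -0.2])

# One (mu, sigma) pair per chain; vmap over axis 0 of thetas only
thetas = jnp.array([[0.0, 1.0], [0.1, 0.9], [-0.1, 1.1], [0.2, 1.2]])
chain_loglike = jax.jit(jax.vmap(loglike, in_axes=(0, None)))
print(chain_loglike(thetas, data))  # one log-likelihood per chain
```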

PRNG Random Number Management

JAX’s pseudo-random number generator (PRNG) management mechanism is a core embodiment of its functional programming philosophy. Unlike traditional frameworks such as NumPy or PyTorch that use “global random state”, JAX employs an explicit, stateless, and splittable random number management system.

Core Mechanism: Explicit Keys and Splitting

In JAX, randomness is not controlled by a global seed that changes over time, but is managed through a random key (PRNGKey).

  • PRNGKey: This is an array object containing the random state. All random functions (such as jax.random.normal) require a Key as input and generate deterministic output based on that Key

  • Key splitting: To generate multiple uncorrelated random number streams, JAX uses the jax.random.split function to decompose a master Key into multiple sub-Keys. This is similar to a tree structure: starting from a root seed, continuously branching into independent random substreams

Why Explicit Key Passing?

JAX requires explicit key passing fundamentally because of its pursuit of pure functions:

  1. Eliminating side effects: Traditional global random state is essentially a “global variable”. Every call to a random function changes this global state, which constitutes a “side effect”. For JIT and operator fusion, functions must be “pure”

  2. Reproducibility: Explicit keys ensure that regardless of where or in what order computation is executed, as long as the Key is the same, results are completely identical

  3. Parallelization and vectorization: In large-scale parallel computation (such as pmap or vmap), global state leads to races and unpredictable results. Explicit keys make PRNG generation easily parallelizable across multiple hardware devices
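Point 3 in concrete form: split once, then map over the resulting keys so each parallel task gets its own independent stream. A small sketch:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# One independent subkey per task; vmap draws from all of them at once
keys = jax.random.split(key, 8)
samples = jax.vmap(lambda k: jax.random.normal(k, shape=(3,)))(keys)
print(samples.shape)  # (8, 3): eight independent draws of size 3
```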

Comparison:

Feature

Traditional Frameworks (PyTorch/NumPy)

JAX

State management

Implicit, global. A background global state exists

Explicit, local. Keys are passed as parameters between functions

Side effects

Yes. Each random function call changes the global state

None. Calling functions does not change the input key, only returns results

Parallel safety

Difficult. Synchronizing global state in parallel environments is complex

Natively supported. Split provides independent Keys for each branch

Determinism

Depends on execution order

Depends only on the passed Key, independent of execution order

Code Example

import jax
import jax.numpy as jnp

# 1. Create an initial key
key = jax.random.PRNGKey(42)

# 2. Using the same key multiple times yields identical results
val1 = jax.random.normal(key)
val2 = jax.random.normal(key)
assert jnp.all(val1 == val2)  # Identical!

# 3. Correct approach: split the key
key, subkey = jax.random.split(key)
random_val = jax.random.normal(subkey)

# 4. Generate multiple independent random numbers
key, *subkeys = jax.random.split(key, 5)
random_vals = [jax.random.normal(k) for k in subkeys]

Analogy:

A traditional random generator is like an “ATM machine” – every time you withdraw money, the bank’s backend balance (global state) decreases, and you cannot return to the state before the withdrawal.

JAX’s PRNG mechanism is like a “duplicable treasure map” – the key is the coordinates on the map. If you copy the map (Split) and give it to another person, and you both follow the same coordinates and steps, the treasure (random numbers) you find will be exactly the same.

This makes it clear in large-scale distributed search (parallel computing) that each person knows exactly which area they are responsible for, without confusion.

JAX in Cosmology

JAX’s application in cosmological computation is triggering a paradigm revolution, driving the field from traditional “numerical integration-driven” approaches toward “fully differentiable inference (Differentiable Universe)”.

Specific Application Areas

Through its automatic differentiation and GPU acceleration capabilities, JAX has permeated every core aspect of cosmological research:

  1. High-dimensional Bayesian inference: When processing next-generation astronomical observation data (such as Stage IV surveys), model parameters often exceed 150. JAX makes the entire likelihood function fully differentiable, supporting efficient exploration of high-dimensional spaces

  2. Cosmological prediction emulators: Using JAX-based neural networks (such as CosmoPower-JAX) to replace traditional Boltzmann solvers (such as CAMB/CLASS). These emulators can predict matter power spectra and other observational effects at sub-millisecond speeds

  3. Differentiable cosmological simulations: JAX is used to write fully differentiable N-body simulations and gravitational lensing simulations (such as JAXtronomy). This enables researchers to perform field-level inference

  4. Fisher matrix computation: Traditional methods rely on unstable numerical differentiation. JAX can automatically compute exact second-order derivatives via jax.hessian

How Does It Accelerate Parameter Inference and MCMC Sampling?

JAX’s acceleration is not a single-dimensional improvement, but achieved through a combination of hardware optimization and algorithmic improvements:

  • From “random blind walk” to “gradient navigation”: Traditional MCMC (such as Metropolis-Hastings) randomly explores parameter space, with efficiency dropping dramatically as dimensionality increases. JAX enables the use of Hamiltonian Monte Carlo (HMC) or NUTS samplers, using gradients to guide the sampler along high-probability regions

  • XLA compiler and operator fusion: Through jax.jit, the XLA compiler transforms Python code into machine code optimized for GPU/TPU, eliminating Python interpretation latency

  • Neural network acceleration: Neural network emulators like CosmoPower-JAX are essentially linear algebra operations, which is exactly what GPUs excel at
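As a hedged illustration of "gradient navigation", here is a minimal leapfrog integrator for one HMC trajectory on a toy standard-normal target. The step size and trajectory length are illustrative, and a full sampler would add momentum resampling and a Metropolis accept/reject step:

```python
import jax
import jax.numpy as jnp

# Toy target density: standard normal
def log_prob(q):
    return -0.5 * jnp.sum(q ** 2)

grad_log_prob = jax.grad(log_prob)  # the gradient that steers the proposal

def leapfrog(q, p, step_size=0.1, n_steps=10):
    p = p + 0.5 * step_size * grad_log_prob(q)   # half-step for momentum
    for _ in range(n_steps - 1):
        q = q + step_size * p                    # full position step
        p = p + step_size * grad_log_prob(q)     # full momentum step
    q = q + step_size * p
    p = p + 0.5 * step_size * grad_log_prob(q)   # final half-step
    return q, p

q0 = jnp.array([1.0, -1.0])
p0 = jnp.array([0.5, 0.5])
q1, p1 = leapfrog(q0, p0)

# Leapfrog is symplectic: the Hamiltonian is approximately conserved,
# which is what allows long, directed moves through parameter space.
H = lambda q, p: -log_prob(q) + 0.5 * jnp.sum(p ** 2)
print(H(q0, p0), H(q1, p1))
```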

Analogy:

Traditional cosmological sampling is like blindly groping through a fog-shrouded valley (random scattering of points), spending enormous effort at each step to probe the depth.

JAX is like providing researchers with a GPS navigation system and a powerful spotlight – gradient information indicates the fastest path to the summit, while GPU hardware acceleration is like equipping researchers with high-speed skateboards, allowing journeys that would take years to be completed in just days.

Cosmological JAX Ecosystem

The cosmology community has already formed a rich JAX software ecosystem:

| Project | Core Use | Features |
| --- | --- | --- |
| jax-cosmo | Core cosmology library | Provides differentiable background evolution, power spectra, and likelihood computation |
| CosmoPower-JAX | Neural network emulator | High-speed prediction of CMB and matter power spectra; supports HMC sampling with hundreds of parameters |
| candl | CMB likelihood analysis | Specialized for analyzing CMB power spectrum measurements (such as SPT, ACT) |
| JAXtronomy | Gravitational lensing simulation | JAX port of lenstronomy; supports GPU acceleration |
| microJAX | Microlensing simulation | First fully differentiable microlensing modeling framework |
| DISCO-DJ | Differentiable N-body simulation | Implements fully automatically differentiable cosmological simulations |
| PyBird-JAX | Effective Field Theory (EFT) | Fast EFT predictions for processing LSS data |

Performance Data and Benchmarks

Based on research data, JAX/XLA demonstrates acceleration capabilities far exceeding traditional frameworks:

Cosmological Inference Acceleration

| Task | Traditional Method | JAX Method | Speedup |
| --- | --- | --- | --- |
| High-dimensional inference (157 parameters) | 12 years (48-core CPU, nested sampling) | 8 days (24 GPUs, gradient sampling) | 10^5x |
| Cosmic shear analysis (37 parameters) | Weeks | Hours | ~1000x |
| Central Bank of Chile economic model | 12 hours (industrial server) | Seconds (consumer GPU) | ~1000x |

Scientific Computing Acceleration

| Operation | NumPy/SciPy | JAX JIT | Speedup |
| --- | --- | --- | --- |
| Supernova likelihood computation | ~50 ms | ~0.05 ms | 1000x |
| Distance calculation (1000 points) | 150 ms | 20 ms | 7.5x |
| JAXtronomy ray tracing (vs CPU) | Baseline | GPU accelerated | 120-140x |
| Solving a 10,000-dimensional PDE | Baseline | JAX optimized | 1000x, with 30x memory savings |

Why HIcosmo Chose JAX

1. MCMC Sampling Requires Efficient Gradient Computation

Modern MCMC samplers (such as NUTS) rely on Hamiltonian Monte Carlo methods, requiring gradient computation of parameters at each step.

Problems with traditional numerical differentiation:

  • For 10 parameters: 20 function evaluations needed (central differences)

  • For 100 parameters: 200 function evaluations needed

  • Precision is sensitive to the choice of the step size epsilon
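The cost scaling is easy to verify by counting calls; a small sketch using central differences on an illustrative 10-parameter function:

```python
import numpy as np

# Count how many function evaluations a central-difference gradient costs
calls = {"n": 0}

def f(x):
    calls["n"] += 1
    return np.sum(np.sin(x))

def numerical_grad(f, x, eps=1e-6):
    g = np.zeros_like(x)
    for i in range(len(x)):
        e = np.zeros_like(x)
        e[i] = eps
        g[i] = (f(x + e) - f(x - e)) / (2 * eps)  # two calls per parameter
    return g

x = np.ones(10)
g = numerical_grad(f, x)
print(calls["n"])  # 20 evaluations for 10 parameters
```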

Advantages of JAX automatic differentiation:

  • Regardless of the number of parameters, a gradient costs only a small constant multiple (roughly 2-4x) of one function evaluation

  • Results are exact mathematical derivatives

  • NumPyro NUTS sampler directly uses JAX’s automatic differentiation

2. Cosmological Computation Requires High-Precision Numerical Integration

Cosmological distance calculations involve integrals:

\[d_C(z) = \frac{c}{H_0} \int_0^z \frac{dz'}{E(z')}\]

HIcosmo approach: Pre-compute high-precision integration tables, extremely fast after JIT compilation

from hicosmo.models import LCDM
import jax.numpy as jnp

z_grid = jnp.linspace(0.01, 2.0, 1000)
params = {'H0': 70.0, 'Omega_m': 0.3}

# First call: compilation ~100ms
grid = LCDM.compute_grid_traced(z_grid, params)

# Subsequent calls: cached ~0.1ms
grid = LCDM.compute_grid_traced(z_grid, params)

3. Fisher Matrix Requires Second-Order Derivatives

The Fisher matrix is the expectation value of the Hessian of the likelihood function:

\[F_{ij} = -\left\langle \frac{\partial^2 \ln L}{\partial \theta_i \partial \theta_j} \right\rangle\]

Traditional method: Numerical second derivatives require \(O(n^2)\) function evaluations and are unstable

JAX method:

import jax

# Automatically obtain the Hessian matrix
hessian = jax.hessian(log_likelihood)

# Fisher matrix
fisher_matrix = -hessian(best_fit_params)
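To check that jax.hessian produces the exact Fisher matrix, here is a toy Gaussian log-likelihood (illustrative names and sigma values) whose answer is known analytically; inverting the Fisher matrix then yields the parameter covariance:

```python
import jax
import jax.numpy as jnp

# Toy Gaussian log-likelihood: Fisher matrix is diag(1/sigma^2) analytically
sigma = jnp.array([0.5, 2.0])

def log_likelihood(theta):
    return jnp.sum(-0.5 * (theta / sigma) ** 2)

fisher = -jax.hessian(log_likelihood)(jnp.zeros(2))
print(fisher)                  # diag(1/sigma^2) = [[4, 0], [0, 0.25]]

cov = jnp.linalg.inv(fisher)   # parameter covariance: diag(sigma^2)
print(jnp.sqrt(jnp.diag(cov))) # recovers the 1-sigma errors [0.5, 2.0]
```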

JAX vs Other Frameworks

Comparison with NumPy

| Feature | NumPy | JAX |
| --- | --- | --- |
| API | Native Python | Nearly identical to NumPy |
| Autodiff | Not supported | Full support |
| JIT compilation | Not supported | Supported |
| GPU support | Requires CuPy | Transparent support |
| Learning curve | Low | Low (if you know NumPy) |

Comparison with TensorFlow/PyTorch

| Feature | TensorFlow | PyTorch | JAX |
| --- | --- | --- | --- |
| Design goal | Deep learning | Deep learning | General scientific computing |
| API style | Computation graph | Dynamic graph | Function transformations |
| Scientific computing friendliness | Average | Average | Excellent |
| NumPy compatibility | Partial | Partial | Full |
| Functional programming | Limited | Limited | Core design |

Why doesn’t HIcosmo use TensorFlow/PyTorch?

  1. API complexity: TensorFlow/PyTorch APIs are designed for neural networks, overly complex for pure numerical computation

  2. NumPy compatibility: JAX can directly use NumPy-style code with low migration cost

  3. Functional programming: Cosmological computations are pure functions, perfectly matching JAX’s function transformation philosophy

  4. Compilation efficiency: The XLA compiler has specialized optimizations for numerical computation

Common JAX Pitfalls and Solutions

1. Immutable Arrays

# Wrong: JAX arrays are immutable
x = jnp.array([1, 2, 3])
x[0] = 5  # TypeError!

# Correct: use .at[].set()
x = x.at[0].set(5)

2. Random Numbers Require Explicit Keys

# Wrong: no global random state
x = jax.random.normal()  # Error!

# Correct: explicitly pass a key
key = jax.random.PRNGKey(42)
x = jax.random.normal(key, shape=(10,))

# Generate a new key
key, subkey = jax.random.split(key)
y = jax.random.normal(subkey, shape=(10,))

3. Dynamic Shapes in JIT

# Wrong: array size depends on a traced value
@jax.jit
def bad_func(x, n):
    return jnp.zeros(n)  # n is a tracer inside jit -> ConcretizationTypeError

# Correct: mark the size argument as static so it is known at compile time
from functools import partial

@partial(jax.jit, static_argnums=1)
def good_func(x, n):
    return jnp.zeros(n)

# Note: shapes themselves (x.shape, len(x)) are static during tracing
# and are always safe to use inside a jitted function.

4. Python Control Flow

# Wrong: Python if in JIT
@jax.jit
def bad_func(x, flag):
    if flag:  # Python conditional
        return x + 1
    return x

# Correct: use jnp.where
@jax.jit
def good_func(x, flag):
    return jnp.where(flag, x + 1, x)
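When the branches are expensive, note that jnp.where computes both of them; jax.lax.cond evaluates only the selected branch and also works under jit. A small sketch:

```python
import jax
import jax.numpy as jnp

# lax.cond(pred, true_fn, false_fn, operand) runs only one branch at runtime
@jax.jit
def branch(x, flag):
    return jax.lax.cond(flag, lambda v: v + 1.0, lambda v: v - 1.0, x)

print(branch(jnp.array(1.0), True))   # 2.0
print(branch(jnp.array(1.0), False))  # 0.0
```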

HIcosmo’s JAX Usage Patterns

Basic Usage

# HIcosmo wraps JAX; users typically do not need to use JAX directly
from hicosmo import hicosmo

# Behind this line of code, JAX is used for:
# - JIT compilation to accelerate distance calculations
# - Automatic differentiation to provide gradients for the NUTS sampler
# - vmap to parallelize multiple MCMC chains
inf = hicosmo("LCDM", "sn", ["H0", "Omega_m"])
samples = inf.run()

Advanced Usage

import jax
import jax.numpy as jnp
from hicosmo.models import LCDM

# Directly use JAX's automatic differentiation
# (z_obs, mu_obs, mu_err: observed redshifts, distance moduli, and errors,
#  assumed to be loaded as JAX arrays)
def my_likelihood(params):
    grid = LCDM.compute_grid_traced(z_obs, params)
    d_L = grid['d_L']
    mu_theory = 5 * jnp.log10(d_L) + 25
    chi2 = jnp.sum(((mu_obs - mu_theory) / mu_err)**2)
    return -0.5 * chi2

# Automatically obtain gradients
grad_likelihood = jax.grad(my_likelihood)

# Automatically obtain the Hessian (for Fisher matrix)
hessian_likelihood = jax.hessian(my_likelihood)

Learning Resources

Community resources:

  • Awesome JAX: Curated list of JAX resources

  • NumPyro: JAX-based probabilistic programming library (used by HIcosmo)

Summary

JAX provides HIcosmo with the following core capabilities:

  1. Automatic differentiation: NUTS sampler requires no manual gradients, supports arbitrarily complex likelihood functions

  2. JIT compilation: 10-1000x performance improvement, XLA operator fusion eliminates memory bottlenecks

  3. Vectorization: vmap enables efficient batch computation and multi-chain parallelism

  4. GPU support: Code runs transparently on GPUs without modification

  5. Composability: grad/jit/vmap can be arbitrarily combined

  6. Functional programming: Pure function design ensures reproducibility and parallel safety

Through this functional architecture, JAX transforms complex physical or mathematical models into optimizable objects, achieving 10^3 to 10^5x performance improvements over traditional methods in high-dimensional inference tasks such as cosmology.

For users, HIcosmo has encapsulated JAX’s complexity. You can use HIcosmo like a regular Python library while enjoying the performance benefits that JAX provides. Only when you need to customize likelihood functions or perform advanced analysis will you need to use the JAX API directly.

Next Steps