JAX Technical Deep Dive
This chapter provides a detailed introduction to JAX’s core concepts and technical advantages, as well as why HIcosmo chose JAX as its computation engine. If you are unfamiliar with JAX, this chapter will help you understand why JAX is a revolutionary tool for modern scientific computing.
What is JAX?
JAX is a high-performance numerical computing library developed by Google. It can be thought of as “differentiable, compilable, parallelizable NumPy”. Its name derives from Just-in-time compilation + Automatic differentiation + XLA (Accelerated Linear Algebra).
Official definition:
JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.
In plain terms:
If you know NumPy, you know JAX (the APIs are nearly identical)
JAX can automatically compute derivatives of any function (including higher-order derivatives)
JAX can automatically compile code into efficient machine code
JAX can transparently run on CPUs, GPUs, and TPUs
JAX is regarded as a differentiable programming language that uses Python as a “metaprogramming language” for building XLA computation graphs.
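The four "plain terms" points above fit in a few lines; this is a minimal sketch using only public JAX APIs:

```python
import jax
import jax.numpy as jnp

# NumPy-style API
x = jnp.linspace(0.0, 1.0, 5)

def f(t):
    return jnp.sum(t ** 2)

# Automatic differentiation: grad transforms f into its gradient function
df = jax.grad(f)

# JIT compilation: the same function, compiled by XLA
fast_f = jax.jit(f)

grad_x = df(x)   # exact gradient: 2 * x
y = fast_f(x)    # same value as f(x), compiled
```

On a machine with a GPU or TPU attached, the same code runs there with no changes.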
Functional Programming: The Core Philosophy of JAX
The design core of JAX is Functional Programming (FP), which fundamentally differs from traditional object-oriented programming (OOP) frameworks such as PyTorch.
Why Functional Programming?
The main reasons JAX adopts functional programming include:
Mathematical consistency: Automatic differentiation is inherently a transformation of functions. Functional programming allows JAX to freely nest and compose `grad` (differentiation), `jit` (compilation), and `vmap` (vectorization) as higher-order functions, e.g. `jit(vmap(grad(f)))`
Easier compiler optimization: The XLA compiler requires a static, side-effect-free computation graph to perform operator fusion and memory optimization. State mutation in OOP (such as modifying class member variables) breaks the construction of such a static graph
Deterministic random state: Traditional OOP frameworks often use global random state, while JAX mandates explicit passing of pseudo-random number generator (PRNG) keys, ensuring code reproducibility
Separation of state and logic: Under the functional paradigm, model parameters are passed as function inputs rather than stored inside objects, making parameter distribution and updates more transparent
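The composability this buys is concrete: the transformations stack as ordinary higher-order functions. A small sketch:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2

# grad -> derivative, vmap -> batching, jit -> compilation, freely nested
df_batched = jax.jit(jax.vmap(jax.grad(f)))

xs = jnp.linspace(0.0, 1.0, 4)
dfs = df_batched(xs)  # element-wise derivative: 2*sin(x)*cos(x) = sin(2x)
```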
Pure Functions
Pure functions are the cornerstone of JAX’s ability to achieve high-performance computation.
Definition: A pure function is a function that does not alter external state and has no side effects.
Property: For the same input, a pure function must always produce the same output.
Why it matters: The predictability of pure functions allows JAX’s XLA compiler to deeply optimize code, enable just-in-time compilation (JIT), and easily achieve operator parallelism and sharding on GPUs/TPUs.
```python
import jax.numpy as jnp

# Pure function: same input -> same output, no side effects
def pure_function(x, y):
    return jnp.sin(x) + jnp.cos(y)

# Impure function: depends on and mutates external state
global_state = 0

def impure_function(x):
    global global_state
    global_state += 1  # Side effect: modifies external state
    return x + global_state
```
Immutable Arrays
In JAX, arrays cannot be modified once created.
Difference from NumPy: Traditional NumPy allows in-place array updates; JAX arrays are immutable.
How it works: If you need to modify an element in an array, JAX does not directly overwrite the original memory. Instead, it creates a new array containing the modified content.
```python
import numpy as np
import jax.numpy as jnp

# NumPy: in-place modification
np_arr = np.array([1, 2, 3])
np_arr[0] = 99  # Direct modification

# JAX: immutable, requires creating a new array
jax_arr = jnp.array([1, 2, 3])
# jax_arr[0] = 99  # TypeError!
jax_arr = jax_arr.at[0].set(99)  # Returns a new array
```
Design rationale: Immutability unlocks compiler optimizations, allowing XLA to safely perform operator fusion and memory reuse.
Gradient Computation Comparison: PyTorch vs JAX
Below are the core differences between PyTorch (object-oriented) and JAX (functional) in handling gradients:
PyTorch (Object-Oriented Paradigm):
Gradients are treated as an attribute of tensor objects, implemented through state mutation.
```python
import torch

def f(x):
    return x**2 + 3*x + 5

x = torch.tensor([2.0], requires_grad=True)
loss = f(x)
loss.backward()  # Triggers autograd and walks back through the computation graph

# The gradient is "stored" in x's .grad attribute (a state change)
print(x.grad.item())  # Output: 7.0
```
JAX (Functional Paradigm):
Differentiation is a function transformation that does not alter the state of any input or object.
```python
import jax

def f(x):
    return x**2 + 3*x + 5

# jax.grad takes a function and returns another function (a function transformation)
df_dx = jax.grad(f)

x_value = 2.0
# The gradient is returned as a value, without modifying any variable
derivative_value = df_dx(x_value)
print(derivative_value)  # Output: 7.0
```
Comparison summary:
| Feature | Object-Oriented (PyTorch) | Functional (JAX) |
|---|---|---|
| State management | State stored inside objects, allows in-place modification | Pure functions, state passed via parameters, immutable |
| Gradient handling | Stored in the tensor's `.grad` attribute (state mutation) | Returned as a value by the `jax.grad`-transformed function (no mutation) |
| Random numbers | Relies on implicit global random state | Requires explicit PRNG key passing |
| Underlying core | Dynamic computation graph, easy to debug | XLA-compiled static computation graph, pursuing maximum performance |
The Four Core Features of JAX
1. NumPy-Compatible API
JAX provides an API that is almost fully compatible with NumPy:
```python
# NumPy code
import numpy as np
x = np.array([1.0, 2.0, 3.0])
y = np.sin(x) + np.cos(x)
z = np.dot(x, y)

# JAX code (just change the import)
import jax.numpy as jnp
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.sin(x) + jnp.cos(x)
z = jnp.dot(x, y)
```
Key differences:
| Feature | NumPy | JAX |
|---|---|---|
| Mutability | Arrays are mutable (`x[0] = 1`) | Arrays are immutable (`x.at[0].set(1)`) |
| Random numbers | Global state (`np.random.seed`) | Explicit key (`jax.random.PRNGKey`) |
| Execution mode | Eager execution | Can be deferred (traced) |
| Hardware support | CPU only | CPU / GPU / TPU |
2. Automatic Differentiation
Automatic differentiation is one of JAX’s most powerful features. It can automatically compute exact derivatives of arbitrarily complex functions.
What is automatic differentiation?
In scientific computing, we frequently need to compute derivatives (gradients) of functions. There are three traditional methods:
Manual/symbolic differentiation: Deriving formulas by hand or with a computer algebra system; error-prone and impractical for complex programs
Numerical differentiation: Finite difference approximation, \(f'(x) \approx \frac{f(x+h) - f(x)}{h}\), low precision and low efficiency
Automatic differentiation: Automatically traces the computation graph via the chain rule, exact and efficient
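The difference between numerical and automatic differentiation is easy to check directly; the comparison below is illustrative, not part of any library:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * jnp.exp(x)

x0 = 1.0
exact = jnp.cos(x0) * jnp.exp(x0) + jnp.sin(x0) * jnp.exp(x0)  # analytic derivative

# Numerical differentiation: accuracy depends on the step size h
h = 1e-3
numeric = (f(x0 + h) - f(x0 - h)) / (2.0 * h)

# Automatic differentiation: exact up to floating-point rounding
auto = jax.grad(f)(x0)
```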
JAX automatic differentiation example:
```python
import jax
import jax.numpy as jnp

# Define a complex function
def f(x):
    return jnp.sin(x) * jnp.exp(-x**2) + jnp.log(1 + x**2)

# Automatically obtain derivative functions
df = jax.grad(f)     # First derivative
d2f = jax.grad(df)   # Second derivative
d3f = jax.grad(d2f)  # Third derivative

# Compute derivative values
x = 1.0
print(f"f(x)   = {f(x):.6f}")
print(f"f'(x)  = {df(x):.6f}")
print(f"f''(x) = {d2f(x):.6f}")
print(f"f'''(x)= {d3f(x):.6f}")
```
Jacobian and Hessian of vector functions:
```python
import jax
import jax.numpy as jnp

# Jacobian of a vector-valued function
jacobian = jax.jacfwd(lambda x: jnp.array([x[0] * x[1], jnp.sin(x[0])]))

# Hessian matrix (used in optimization and the Fisher matrix)
def scalar_func(x):
    return jnp.sum(x**2)

hessian = jax.hessian(scalar_func)
```
Advantages of automatic differentiation:
Exact: Results are exact mathematical derivatives, not numerical approximations
Efficient: A reverse-mode gradient costs only a small constant multiple of one function evaluation, regardless of the number of inputs
General: Supports arbitrarily complex function compositions, including conditionals and loops
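A small illustration of the last point: `jax.grad` differentiates straight through an ordinary Python loop, which is unrolled during tracing:

```python
import jax
import jax.numpy as jnp

# Autodiff works through Python control flow: the loop is unrolled while tracing
def f(x):
    for _ in range(3):
        x = jnp.sin(x) + x
    return x

df = jax.grad(f)
slope = df(0.5)  # exact derivative of the three-fold composition
```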
3. JIT Compilation and XLA Compiler
JIT compilation is the core source of JAX’s performance. It compiles Python code into optimized machine code.
How the XLA Compiler Works
XLA (Accelerated Linear Algebra) is a compiler specifically designed for linear algebra operations. It transforms programs written in JAX into optimized executable kernels for specific hardware (CPU, GPU, or TPU).
The workflow consists of the following stages:
1. Tracing: When you apply `jax.jit` to a Python function and call it for the first time, JAX uses a "tracing" mechanism to record the sequence of operations the function performs on arrays of specific shapes and types
2. Building the computation graph: The traced information is transformed into an XLA computation graph (HLO, the High-Level Optimizer representation)
3. Hardware compilation and optimization: The XLA compiler analyzes the whole graph, applies optimizations such as dead code elimination and operator fusion, and finally compiles it into binary machine code that runs directly on hardware accelerators
Operator Fusion
Operator fusion is one of XLA’s most powerful optimization techniques.
Principle: In traditional vectorized programming, each simple operation (such as A * B + C) typically generates many intermediate temporary arrays. These intermediate results need to be written back to memory and then read again, causing significant memory bandwidth waste.
Implementation: XLA identifies consecutive operations in the computation graph and “fuses” them into a single GPU kernel. This means intermediate computation results are kept directly in the processor’s fast registers or cache, rather than frequently exchanging data with main memory.
Benefit: Significantly reduces memory bandwidth usage and memory allocation pressure, particularly avoiding OOM (out of memory) errors when processing high-dimensional cosmological data.
Analogy:
Think of traditional Python operations like ordering at a fast-food counter: every time you order a burger, the server runs back to the kitchen (calls a C library), makes it, brings it to you, then you order fries, and the server runs back again.
The XLA compiler and JIT are like a smart head chef: you tell them you want a combo meal (the entire function), they analyze the menu, decide to fry the fries and grill the patty simultaneously (operator fusion), and finally deliver the entire hot combo meal to you at once.
This “one-stop” processing greatly reduces the waiting time in transit (Python overhead and memory read/write).
JIT compilation example:
```python
import jax
import jax.numpy as jnp
import time

# Define a compute-intensive function
def slow_function(x):
    for _ in range(100):
        x = jnp.sin(x) + jnp.cos(x)
    return x

# JIT-compiled version
fast_function = jax.jit(slow_function)

x = jnp.ones(10000)

# First call: includes compilation time
# (block_until_ready forces the async dispatch to finish before we stop the clock)
start = time.time()
fast_function(x).block_until_ready()
print(f"First call (with compilation): {time.time() - start:.4f}s")

# Subsequent calls: use the cached compiled result
start = time.time()
for _ in range(100):
    fast_function(x).block_until_ready()
print(f"Subsequent 100 calls: {time.time() - start:.4f}s")
```
4. Vectorization with vmap
vmap is JAX’s vectorization transformation that automatically converts scalar functions into batched functions.
Example:
```python
import jax
import jax.numpy as jnp

# Define a function that processes a single sample
def single_sample_loglike(theta, data_point):
    mu, sigma = theta
    return -0.5 * ((data_point - mu) / sigma)**2 - jnp.log(sigma)

# Automatically batch using vmap: map over axis 0 of the data, broadcast theta
batch_loglike = jax.vmap(single_sample_loglike, in_axes=(None, 0))

theta = (0.0, 1.0)
data = jnp.array([0.1, 0.5, -0.3, 0.8, -0.2])

# Compute log-likelihoods for all data points at once
log_likes = batch_loglike(theta, data)
```
Performance comparison (1000 data points):
Python loop: ~50ms
NumPy vectorized: ~2ms
JAX vmap + JIT: ~0.1ms
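The same pattern scales from single data points to whole parameter vectors, e.g. evaluating one likelihood for several MCMC-style parameter sets at once. The batched likelihood below is a hypothetical illustration, not HIcosmo code:

```python
import jax
import jax.numpy as jnp

# Hypothetical likelihood: one scalar per parameter vector
def loglike(theta, data):
    mu, sigma = theta
    return jnp.sum(-0.5 * ((data - mu) / sigma) ** 2 - jnp.log(sigma))

data = jnp.array([0.1, 0.5, -0.3, 0.8, -0.2])

# Stack of parameter vectors, e.g. the current states of three MCMC chains
thetas = jnp.array([[0.0, 1.0], [0.1, 0.9], [-0.1, 1.1]])

# Map over axis 0 of thetas, share the same data across the batch
batched = jax.vmap(loglike, in_axes=(0, None))
lls = batched(thetas, data)  # shape (3,): one log-likelihood per chain
```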
PRNG Random Number Management
JAX’s pseudo-random number generator (PRNG) management mechanism is a core embodiment of its functional programming philosophy. Unlike traditional frameworks such as NumPy or PyTorch that use “global random state”, JAX employs an explicit, stateless, and splittable random number management system.
Core Mechanism: Explicit Keys and Splitting
In JAX, randomness is not controlled by a global seed that changes over time, but is managed through a random key (PRNGKey).
PRNGKey: An array object containing the random state. All random functions (such as `jax.random.normal`) take a Key as input and generate deterministic output based on that Key
Key splitting: To generate multiple uncorrelated random-number streams, JAX uses `jax.random.split` to decompose a master Key into multiple sub-Keys. This resembles a tree: starting from a root seed, it keeps branching into independent random substreams
Why Explicit Key Passing?
JAX requires explicit key passing fundamentally because of its pursuit of pure functions:
Eliminating side effects: Traditional global random state is essentially a “global variable”. Every call to a random function changes this global state, which constitutes a “side effect”. For JIT and operator fusion, functions must be “pure”
Reproducibility: Explicit keys ensure that regardless of where or in what order computation is executed, as long as the Key is the same, results are completely identical
Parallelization and vectorization: In large-scale parallel computation (such as `pmap` or `vmap`), global state leads to races and unpredictable results. Explicit keys make PRNG generation easily parallelizable across multiple hardware devices
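In practice, the split-then-map pattern looks like this (a generic sketch, not tied to any particular sampler):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# One independent subkey per parallel lane: no shared state, no races
keys = jax.random.split(key, 8)
samples = jax.vmap(lambda k: jax.random.normal(k, shape=(3,)))(keys)
# samples has shape (8, 3): eight statistically independent streams
```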
Comparison:
| Feature | Traditional Frameworks (PyTorch/NumPy) | JAX |
|---|---|---|
| State management | Implicit, global: a hidden global state exists in the background | Explicit, local: keys are passed as parameters between functions |
| Side effects | Yes: each random function call changes the global state | None: calling a function does not change the input key, it only returns results |
| Parallel safety | Difficult: synchronizing global state in parallel environments is complex | Natively supported: `split` provides independent keys for each branch |
| Determinism | Depends on execution order | Depends only on the key passed in, independent of execution order |
Code Example
```python
import jax
import jax.numpy as jnp

# 1. Create an initial key
key = jax.random.PRNGKey(42)

# 2. Using the same key twice yields identical results
val1 = jax.random.normal(key)
val2 = jax.random.normal(key)
assert jnp.all(val1 == val2)  # Identical!

# 3. Correct approach: split the key before drawing again
key, subkey = jax.random.split(key)
random_val = jax.random.normal(subkey)

# 4. Generate multiple independent random numbers
key, *subkeys = jax.random.split(key, 5)
random_vals = [jax.random.normal(k) for k in subkeys]
```
Analogy:
A traditional random generator is like an “ATM machine” – every time you withdraw money, the bank’s backend balance (global state) decreases, and you cannot return to the state before the withdrawal.
JAX’s PRNG mechanism is like a “duplicable treasure map” – the key is the coordinates on the map. If you copy the map (Split) and give it to another person, and you both follow the same coordinates and steps, the treasure (random numbers) you find will be exactly the same.
This makes it clear in large-scale distributed search (parallel computing) that each person knows exactly which area they are responsible for, without confusion.
JAX in Cosmology
JAX’s application in cosmological computation is triggering a paradigm revolution, driving the field from traditional “numerical integration-driven” approaches toward “fully differentiable inference (Differentiable Universe)”.
Specific Application Areas
Through its automatic differentiation and GPU acceleration capabilities, JAX has permeated every core aspect of cosmological research:
High-dimensional Bayesian inference: When processing next-generation astronomical observation data (such as Stage IV surveys), model parameters often exceed 150. JAX makes the entire likelihood function fully differentiable, supporting efficient exploration of high-dimensional spaces
Cosmological prediction emulators: Using JAX-based neural networks (such as CosmoPower-JAX) to replace traditional Boltzmann solvers (such as CAMB/CLASS). These emulators can predict matter power spectra and other observational effects at sub-millisecond speeds
Differentiable cosmological simulations: JAX is used to write fully differentiable N-body simulations and gravitational lensing simulations (such as JAXtronomy). This enables researchers to perform field-level inference
Fisher matrix computation: Traditional methods rely on unstable numerical differentiation; JAX computes exact second-order derivatives automatically via `jax.hessian`
How Does It Accelerate Parameter Inference and MCMC Sampling?
JAX’s acceleration is not a single-dimensional improvement, but achieved through a combination of hardware optimization and algorithmic improvements:
From “random blind walk” to “gradient navigation”: Traditional MCMC (such as Metropolis-Hastings) randomly explores parameter space, with efficiency dropping dramatically as dimensionality increases. JAX enables the use of Hamiltonian Monte Carlo (HMC) or NUTS samplers, using gradients to guide the sampler along high-probability regions
XLA compiler and operator fusion: Through `jax.jit`, the XLA compiler transforms Python code into machine code optimized for GPU/TPU, eliminating Python interpretation latency
Neural network acceleration: Neural network emulators like CosmoPower-JAX reduce to dense linear algebra, which is exactly what GPUs excel at
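To make the "gradient navigation" point concrete, here is a schematic leapfrog integrator of the kind HMC uses internally. It is a toy sketch on a Gaussian target, not any library's actual sampler:

```python
import jax
import jax.numpy as jnp

def log_prob(q):
    return -0.5 * jnp.sum(q ** 2)  # toy target: standard Gaussian

grad_log_prob = jax.grad(log_prob)

def leapfrog(q, p, step_size, n_steps):
    """One HMC trajectory: gradients steer the proposal toward high probability."""
    p = p + 0.5 * step_size * grad_log_prob(q)  # initial half momentum kick
    for _ in range(n_steps - 1):
        q = q + step_size * p                   # position drift
        p = p + step_size * grad_log_prob(q)    # full momentum kick
    q = q + step_size * p
    p = p + 0.5 * step_size * grad_log_prob(q)  # final half kick
    return q, p

q0 = jnp.array([1.0, -1.0])
p0 = jnp.array([0.5, 0.5])
q1, p1 = leapfrog(q0, p0, step_size=0.1, n_steps=10)
# Total energy is approximately conserved along the trajectory,
# which is what makes gradient-guided proposals so efficient.
```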
Analogy:
Traditional cosmological sampling is like blindly groping through a fog-shrouded valley (random scattering of points), spending enormous effort at each step to probe the depth.
JAX is like providing researchers with a GPS navigation system and a powerful spotlight – gradient information indicates the fastest path to the summit, while GPU hardware acceleration is like equipping researchers with high-speed skateboards, allowing journeys that would take years to be completed in just days.
Cosmological JAX Ecosystem
The cosmology community has already formed a rich JAX software ecosystem:
| Project | Core Use | Features |
|---|---|---|
| jax-cosmo | Core cosmology library | Provides differentiable background evolution, power spectra, and likelihood computation |
| CosmoPower-JAX | Neural network emulator | High-speed prediction of CMB and matter power spectra; supports HMC sampling with hundreds of parameters |
| candl | CMB likelihood analysis | Specialized for analyzing CMB power spectrum measurements (such as SPT, ACT) |
| JAXtronomy | Gravitational lensing simulation | JAX port of lenstronomy; supports GPU acceleration |
| microJAX | Microlensing simulation | First fully differentiable microlensing modeling framework |
| DISCO-DJ | Differentiable N-body simulation | Implements fully differentiable cosmological simulations |
| PyBird-JAX | Effective Field Theory (EFT) | Fast EFT predictions for LSS data |
Performance Data and Benchmarks
Reported benchmarks show that JAX/XLA can deliver speedups far beyond traditional frameworks:
Cosmological Inference Acceleration
| Task | Traditional Method | JAX Method | Speedup |
|---|---|---|---|
| High-dimensional inference (157 parameters) | 12 years (48-core CPU, nested sampling) | 8 days (24 GPUs, gradient sampling) | ~10^5x |
| Cosmic shear analysis (37 parameters) | Weeks | Hours | ~1000x |
| Central Bank of Chile economic model | 12 hours (industrial server) | Seconds (consumer GPU) | ~1000x |
Scientific Computing Acceleration
| Operation | NumPy/SciPy | JAX JIT | Speedup |
|---|---|---|---|
| Supernova likelihood computation | ~50 ms | ~0.05 ms | 1000x |
| Distance calculation (1000 points) | 150 ms | 20 ms | 7.5x |
| JAXtronomy ray tracing (vs CPU) | Baseline | GPU accelerated | 120-140x |
| Solving 10,000-dimensional PDE | Baseline | JAX optimized | 1000x + 30x memory savings |
Why HIcosmo Chose JAX
1. MCMC Sampling Requires Efficient Gradient Computation
Modern MCMC samplers (such as NUTS) rely on Hamiltonian Monte Carlo methods, requiring gradient computation of parameters at each step.
Problems with traditional numerical differentiation:
For 10 parameters: 20 function evaluations per gradient (two per parameter with central differences)
For 100 parameters: 200 function evaluations per gradient
Precision is sensitive to the choice of step size \(\epsilon\)
Advantages of JAX automatic differentiation:
Regardless of the number of parameters, one reverse-mode pass costs only a small constant multiple of a single function evaluation
Results are exact mathematical derivatives
NumPyro NUTS sampler directly uses JAX’s automatic differentiation
2. Cosmological Computation Requires High-Precision Numerical Integration
Cosmological distance calculations involve integrals such as the comoving distance \(d_C(z) = c \int_0^z \frac{dz'}{H(z')}\), which has no closed form for general dark energy models:
HIcosmo approach: Pre-compute high-precision integration tables, extremely fast after JIT compilation
```python
from hicosmo.models import LCDM
import jax.numpy as jnp

z_grid = jnp.linspace(0.01, 2.0, 1000)
params = {'H0': 70.0, 'Omega_m': 0.3}

# First call: compilation, ~100 ms
grid = LCDM.compute_grid_traced(z_grid, params)

# Subsequent calls: cached, ~0.1 ms
grid = LCDM.compute_grid_traced(z_grid, params)
```
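The idea behind such a grid can be sketched generically. This is an illustrative cumulative-trapezoid version, not HIcosmo's internal implementation; `comoving_distance_table` and `C_KM_S` are names invented here:

```python
import jax
import jax.numpy as jnp

C_KM_S = 299792.458  # speed of light in km/s

@jax.jit
def comoving_distance_table(z_grid, H0, Omega_m):
    """Tabulate D_C(z) = c * integral_0^z dz'/H(z') via a cumulative trapezoid rule."""
    E = jnp.sqrt(Omega_m * (1.0 + z_grid) ** 3 + (1.0 - Omega_m))  # flat LCDM
    integrand = 1.0 / (H0 * E)  # 1/H(z)
    dz = jnp.diff(z_grid)
    steps = 0.5 * (integrand[1:] + integrand[:-1]) * dz
    return C_KM_S * jnp.concatenate([jnp.zeros(1), jnp.cumsum(steps)])  # Mpc

z_grid = jnp.linspace(0.0, 2.0, 1000)
table = comoving_distance_table(z_grid, 70.0, 0.3)

# Cheap lookups afterwards: interpolate into the precomputed table
d_c = jnp.interp(jnp.array([0.5, 1.0]), z_grid, table)
```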
3. Fisher Matrix Requires Second-Order Derivatives
The Fisher matrix is the negative expectation of the Hessian of the log-likelihood, \(F_{ij} = -\left\langle \frac{\partial^2 \ln \mathcal{L}}{\partial \theta_i \, \partial \theta_j} \right\rangle\):
Traditional method: Numerical second derivatives require \(O(n^2)\) function evaluations and are unstable
JAX method:
```python
import jax

# Automatically obtain the Hessian of the log-likelihood
hessian = jax.hessian(log_likelihood)

# Fisher matrix: negative Hessian evaluated at the best fit
fisher_matrix = -hessian(best_fit_params)
```
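On a concrete toy likelihood (a hypothetical linear-model stand-in, not a real cosmological one), the Fisher matrix then inverts directly into a forecast parameter covariance:

```python
import jax
import jax.numpy as jnp

# Hypothetical toy data and model: mu = a + b * x with Gaussian errors
data = jnp.array([1.1, 0.9, 1.05, 0.95])
sigma = 0.1
x = jnp.arange(4.0)

def log_likelihood(params):
    a, b = params
    mu = a + b * x
    return -0.5 * jnp.sum(((data - mu) / sigma) ** 2)

best_fit = jnp.array([1.0, 0.0])
fisher = -jax.hessian(log_likelihood)(best_fit)

# Invert the Fisher matrix for the forecast parameter covariance
cov = jnp.linalg.inv(fisher)
errors = jnp.sqrt(jnp.diag(cov))  # 1-sigma uncertainties on (a, b)
```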
JAX vs Other Frameworks
Comparison with NumPy
| Feature | NumPy | JAX |
|---|---|---|
| API | Native Python | Nearly identical to NumPy |
| Autodiff | Not supported | Full support |
| JIT compilation | Not supported | Supported |
| GPU support | Requires CuPy | Transparent support |
| Learning curve | Low | Low (if you know NumPy) |
Comparison with TensorFlow/PyTorch
| Feature | TensorFlow | PyTorch | JAX |
|---|---|---|---|
| Design goal | Deep learning | Deep learning | General scientific computing |
| API style | Computation graph | Dynamic graph | Function transformations |
| Scientific computing friendliness | Average | Average | Excellent |
| NumPy compatibility | Partial | Partial | Full |
| Functional programming | Limited | Limited | Core design |
Why doesn’t HIcosmo use TensorFlow/PyTorch?
API complexity: TensorFlow/PyTorch APIs are designed for neural networks, overly complex for pure numerical computation
NumPy compatibility: JAX can directly use NumPy-style code with low migration cost
Functional programming: Cosmological computations are pure functions, perfectly matching JAX’s function transformation philosophy
Compilation efficiency: The XLA compiler has specialized optimizations for numerical computation
Common JAX Pitfalls and Solutions
1. Immutable Arrays
```python
import jax.numpy as jnp

# Wrong: JAX arrays are immutable
x = jnp.array([1, 2, 3])
# x[0] = 5  # TypeError!

# Correct: use .at[].set(), which returns a new array
x = x.at[0].set(5)
```
2. Random Numbers Require Explicit Keys
```python
import jax

# Wrong: there is no global random state
# x = jax.random.normal()  # TypeError: missing key!

# Correct: explicitly pass a key
key = jax.random.PRNGKey(42)
x = jax.random.normal(key, shape=(10,))

# Split off a new key before drawing again
key, subkey = jax.random.split(key)
y = jax.random.normal(subkey, shape=(10,))
```
3. Dynamic Shapes in JIT
Note that array shapes are static under tracing, so `len(x)` is known at compile time; the real pitfall is a shape that depends on array values:

```python
import jax
import jax.numpy as jnp

# Wrong: output shape depends on array *values*, unknown at trace time
@jax.jit
def bad_func(x):
    return x[x > 0]  # boolean masking -> dynamic shape, raises an error

# Correct: keep shapes static, e.g. mask with jnp.where instead
@jax.jit
def good_func(x):
    return jnp.where(x > 0, x, 0.0)
```
4. Python Control Flow
```python
import jax
import jax.numpy as jnp

# Wrong: Python `if` on a traced value inside JIT
@jax.jit
def bad_func(x, flag):
    if flag:  # errors: flag is a tracer, its truth value is unknown
        return x + 1
    return x

# Correct: use jnp.where
@jax.jit
def good_func(x, flag):
    return jnp.where(flag, x + 1, x)
```
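One caveat: `jnp.where` evaluates both branches. When a branch is expensive, `jax.lax.cond` expresses the conditional structurally so that only the selected branch runs (outside of `vmap`):

```python
import jax
import jax.numpy as jnp

@jax.jit
def cond_func(x, flag):
    # Structured conditional: only the chosen branch executes at run time
    return jax.lax.cond(flag, lambda v: v + 1.0, lambda v: v, x)

r1 = cond_func(jnp.array(1.0), True)   # 2.0
r2 = cond_func(jnp.array(1.0), False)  # 1.0
```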
HIcosmo’s JAX Usage Patterns
Basic Usage
```python
# HIcosmo wraps JAX; users typically do not need to use JAX directly
from hicosmo import hicosmo

# Behind this line of code, JAX is used for:
# - JIT compilation to accelerate distance calculations
# - Automatic differentiation to provide gradients for the NUTS sampler
# - vmap to parallelize multiple MCMC chains
inf = hicosmo("LCDM", "sn", ["H0", "Omega_m"])
samples = inf.run()
```
Advanced Usage
```python
import jax
import jax.numpy as jnp
from hicosmo.models import LCDM

# Directly use JAX's automatic differentiation
# (z_obs, mu_obs, mu_err are the observed redshifts, distance moduli, and errors)
def my_likelihood(params):
    grid = LCDM.compute_grid_traced(z_obs, params)
    d_L = grid['d_L']
    mu_theory = 5 * jnp.log10(d_L) + 25
    chi2 = jnp.sum(((mu_obs - mu_theory) / mu_err)**2)
    return -0.5 * chi2

# Automatically obtain gradients
grad_likelihood = jax.grad(my_likelihood)

# Automatically obtain the Hessian (for the Fisher matrix)
hessian_likelihood = jax.hessian(my_likelihood)
```
Learning Resources
Community resources:
Awesome JAX: Curated list of JAX resources
NumPyro: JAX-based probabilistic programming library (used by HIcosmo)
Cosmological JAX projects:
jax-cosmo: Differentiable cosmology
CosmoPower-JAX: Neural network emulator
JAXtronomy: Gravitational lensing simulation
Summary
JAX provides HIcosmo with the following core capabilities:
Automatic differentiation: NUTS sampler requires no manual gradients, supports arbitrarily complex likelihood functions
JIT compilation: 10-1000x performance improvement, XLA operator fusion eliminates memory bottlenecks
Vectorization: vmap enables efficient batch computation and multi-chain parallelism
GPU support: Code runs transparently on GPUs without modification
Composability: grad/jit/vmap can be arbitrarily combined
Functional programming: Pure function design ensures reproducibility and parallel safety
Through this functional architecture, JAX transforms complex physical or mathematical models into optimizable objects, achieving 10^3 to 10^5x performance improvements over traditional methods in high-dimensional inference tasks such as cosmology.
For users, HIcosmo has encapsulated JAX’s complexity. You can use HIcosmo like a regular Python library while enjoying the performance benefits that JAX provides. Only when you need to customize likelihood functions or perform advanced analysis will you need to use the JAX API directly.
Next Steps
Core Concepts: Understand HIcosmo’s architecture design
Samplers: Learn about MCMC configuration
Fisher Forecasts: Learn about Fisher matrix analysis