Parallelism and Acceleration Configuration

HIcosmo is built on JAX, with native support for multi-core CPU parallelism and GPU acceleration. This chapter explains in detail how to configure the parallel environment to fully leverage your hardware.

Quick Start

Configure parallelism in one line:

import hicosmo as hc

# Most common: 8-core parallelism
hc.init(8)

# Auto-detect optimal configuration
hc.init()

# GPU mode
hc.init("GPU")

Important

You must call hc.init() before importing JAX!

JAX’s device configuration is fixed at the time of first import and cannot be changed afterwards. Therefore hc.init() should be the very first line of your script.

# ✅ Correct order
import hicosmo as hc
hc.init(8)

from hicosmo.models import LCDM  # JAX is imported here
from hicosmo.samplers import MCMC

# ❌ Wrong order
from hicosmo.models import LCDM  # JAX already imported, device count fixed to 1
import hicosmo as hc
hc.init(8)  # ⚠️ Warning: configuration ignored

Initialization API

hc.init() Function

import hicosmo as hc

hc.init(
    num_devices='auto',  # Number of devices
    device=None,         # Device type ('GPU' or None)
    verbose=True,        # Whether to display configuration info
    force=True           # Whether to force-override existing configuration
)

Parameter Description:

| Parameter | Type | Description |
|-----------|------|-------------|
| num_devices | int / str / None | Number of parallel devices: int = explicit count (e.g., 8); 'auto' = auto-detect (default, up to 8); 'GPU' = GPU mode; None = single-device vectorized mode |
| device | str / None | Explicit device type: None = CPU mode (default); 'GPU' / 'cuda' = GPU mode |
| verbose | bool | Whether to print configuration info |
| force | bool | Whether to override existing XLA_FLAGS |
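When num_devices is None, sampling stays on a single device but is vectorized across chains; in JAX terms that corresponds to jax.vmap rather than jax.pmap. A minimal sketch of the difference (the toy log_like function here is our own illustration, not part of HIcosmo):

```python
import jax
import jax.numpy as jnp

def log_like(theta):
    # Toy Gaussian log-likelihood, for illustration only
    return -0.5 * jnp.sum(theta ** 2)

thetas = jnp.ones((8, 2))           # 8 chains, 2 parameters each
logps = jax.vmap(log_like)(thetas)  # one vectorized pass on a single device
print(logps.shape)  # (8,)
```

With multiple logical devices configured, the same batch could instead be split one chain per device; vectorized mode trades that hardware parallelism for zero configuration.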

Multi-Core CPU Parallelism

Basic Configuration

import hicosmo as hc

# Use 8 CPU cores
hc.init(8)

# Running MCMC will automatically use 8 parallel chains
from hicosmo.samplers import MCMC
mcmc = MCMC(params, likelihood, chain_name='test')
mcmc.run(num_samples=2000, num_chains=8)  # 8 chains run in parallel

How It Works

When you call hc.init(8), HIcosmo will:

  1. Set XLA_FLAGS: Configure JAX to use 8 logical CPU devices

  2. Configure thread pool: Assign CPU cores to each device

  3. Set up NumPyro: Inform NumPyro to use 8 devices

┌──────────────────────────────────────────────────────────┐
│                      CPU (8 cores)                       │
├──────────┬──────────┬──────────┬──────────┬──────────────┤
│ Device 0 │ Device 1 │ Device 2 │ Device 3 │ ... Device 7 │
│ (Chain 1)│ (Chain 2)│ (Chain 3)│ (Chain 4)│ ... (Chain 8)│
└──────────┴──────────┴──────────┴──────────┴──────────────┘
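Step 1 above boils down to setting the XLA host-device count before JAX is first imported. A minimal standalone sketch of that mechanism, using plain JAX rather than HIcosmo:

```python
import os

# Must run before JAX is first imported in this process
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax

# JAX now reports 8 logical CPU devices backed by the host's cores
print(len(jax.devices()))  # 8, provided JAX was not imported earlier
```

This is also why hc.init() must come first: once jax is imported, the flag is read and the device count can no longer change.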

GPU Acceleration

Basic Configuration

import hicosmo as hc

# Auto-detect GPU
hc.init("GPU")

# Or specify explicitly
hc.init(device="GPU")

Multi-GPU Configuration

import hicosmo as hc

# Use 4 GPUs (if available)
hc.init(4, device="GPU")

# Auto-detect all available GPUs
hc.init("GPU")  # Automatically detects GPU count

GPU vs CPU Selection

| Scenario | Recommendation | Reason |
|----------|----------------|--------|
| Simple likelihoods (SNe, BAO) | CPU | GPU startup overhead exceeds the computational benefit |
| Complex likelihoods (CMB power spectrum) | GPU | Large-scale matrix operations benefit from the GPU |
| Large-scale Fisher matrices | GPU | Matrix inversion is well suited to GPUs |
| Long MCMC runs (>100k samples) | GPU | The cumulative speedup is significant |
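The matrix-heavy rows above can be made concrete: dense linear algebra such as Fisher-matrix inversion is dispatched by JAX to whichever backend hc.init() selected, so the same code runs on CPU or GPU. A backend-agnostic sketch (the "Fisher matrix" here is a synthetic stand-in, not a real HIcosmo object):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
A = jax.random.normal(key, (500, 500))
F = A @ A.T + 500 * jnp.eye(500)  # symmetric positive-definite stand-in "Fisher matrix"

cov = jnp.linalg.inv(F)           # runs on the GPU if one was configured
cov = cov.block_until_ready()     # force execution (useful when timing)
```

On a GPU this inversion is typically much faster at large dimensions, while for small matrices the dispatch overhead dominates, which is the trade-off the table summarizes.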

GPU Environment Setup

Make sure the JAX GPU version is installed:

# CUDA 12 (quote the extras so the shell does not expand the brackets)
pip install "jax[cuda12]"

# CUDA 11
pip install "jax[cuda11_pip]"

Verify GPU availability:

import jax
print(jax.devices())
# [cuda(id=0), cuda(id=1), ...]  # GPU available
# [CpuDevice(id=0)]              # CPU only

Cluster Configuration

SLURM Cluster

Running HIcosmo on a SLURM cluster:

#!/bin/bash
#SBATCH --job-name=hicosmo
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8
#SBATCH --mem=32G
#SBATCH --time=24:00:00

# Activate environment
source activate hicosmo

# Run script
python my_mcmc_script.py

Corresponding Python script:

import hicosmo as hc

# Use the 8 cores allocated by SLURM
hc.init(8)

# ... rest of the code

PBS/Torque Cluster

#!/bin/bash
#PBS -N hicosmo
#PBS -l nodes=1:ppn=8
#PBS -l mem=32gb
#PBS -l walltime=24:00:00

cd $PBS_O_WORKDIR
source activate hicosmo
python my_mcmc_script.py

GPU Cluster

#!/bin/bash
#SBATCH --job-name=hicosmo_gpu
#SBATCH --nodes=1
#SBATCH --gres=gpu:4
#SBATCH --mem=64G
#SBATCH --time=24:00:00

source activate hicosmo
python my_mcmc_script.py

Corresponding Python script:

import hicosmo as hc

# Use the 4 GPUs allocated by SLURM
hc.init(4, device="GPU")

Environment Variable Configuration

HIcosmo automatically manages the following environment variables; manual setup is usually not needed:

| Environment Variable | Description |
|----------------------|-------------|
| XLA_FLAGS | JAX/XLA compiler configuration, including the device count |
| JAX_ENABLE_X64 | Enable 64-bit precision (default: True) |
| JAX_NUM_THREADS | Number of threads per device |
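The effect of JAX_ENABLE_X64 is easy to verify: with 64-bit precision enabled, JAX arrays default to float64 rather than float32. A quick check using jax.config.update, which is equivalent to setting the environment variable:

```python
import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)   # same effect as JAX_ENABLE_X64=True
print(jnp.asarray(1.0).dtype)  # float64

jax.config.update('jax_enable_x64', False)  # back to 32-bit defaults
print(jnp.asarray(1.0).dtype)  # float32
```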

Manual Configuration (Advanced)

For finer-grained control:

# Set in shell
export XLA_FLAGS="--xla_force_host_platform_device_count=8"
export JAX_ENABLE_X64=True
python my_script.py

Or set them in Python, before any JAX import:

import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

# Must be after setting environment variables and before importing JAX
import hicosmo as hc
hc.init(8, force=False)  # Do not override already-set XLA_FLAGS

Viewing Current Configuration

import hicosmo as hc

hc.init(8)

# View configuration status
status = hc.Config.status()
print(status)
# {
#     'initialized': True,
#     'config': {'num_devices': 8, 'device_type': 'cpu', ...},
#     'system_cores': 16,
#     'jax_devices': 8,
#     'jax_device_list': ['CpuDevice(id=0)', 'CpuDevice(id=1)', ...]
# }

# View JAX devices
import jax
print(f"Available devices: {len(jax.devices())}")
for d in jax.devices():
    print(f"  {d}")

Performance Tuning Tips

  1. Match the number of chains to the number of devices

    hc.init(8)
    # num_chains should be equal to or less than the device count
    mcmc.run(num_samples=2000, num_chains=8)  # ✅ Optimal
    mcmc.run(num_samples=2000, num_chains=4)  # ✅ OK, but wastes 4 devices
    mcmc.run(num_samples=2000, num_chains=16) # ⚠️ Exceeds device count, will queue
    
  2. Warm up JIT compilation

    # The first call triggers JIT compilation (slower)
    result = likelihood(H0=70, Omega_m=0.3)
    
    # Subsequent calls use the cache (fast)
    result = likelihood(H0=71, Omega_m=0.3)
    
  3. Memory management

    # Monitor memory during large-scale sampling
    import jax
    jax.clear_caches()  # Clear JIT caches to free memory
    
  4. Avoid frequent initialization

    # ❌ Do not initialize inside a loop
    for i in range(10):
        hc.init(8)  # Will print a warning each time
        ...
    
    # ✅ Initialize only once
    hc.init(8)
    for i in range(10):
        ...
    

FAQ

Q: Why was my configuration ignored?

The most common cause is that JAX has already been imported:

from hicosmo.models import LCDM  # ❌ JAX is imported here
import hicosmo as hc
hc.init(8)  # ⚠️ Warning: configuration ignored

Solution: Make sure hc.init() is called before all JAX-related imports.

Q: GPU not detected?

  1. Check if the JAX GPU version is installed:

    import jax
    print(jax.devices())  # Should show cuda devices
    
  2. Check the CUDA environment:

    nvidia-smi  # Should display GPU information
    
  3. Reinstall the JAX GPU version:

    pip uninstall jax jaxlib
    pip install jax[cuda12]
    

Q: How to use in a Jupyter Notebook?

In the first cell of the notebook:

# Cell 1 - Must be the first cell!
import hicosmo as hc
hc.init(8)

# Cell 2 - Then import other modules
from hicosmo.models import LCDM
from hicosmo.samplers import MCMC

Warning

If you restart the kernel, you need to re-run hc.init().

Q: What to do about out-of-memory errors?

  1. Reduce the number of parallel chains:

    hc.init(4)  # Instead of 8
    
  2. Reduce the number of samples:

    mcmc.run(num_samples=1000)  # Instead of 5000
    
  3. Disable 64-bit precision if your analysis tolerates it:

    # 32-bit precision roughly halves memory usage;
    # set this before hc.init() and any JAX import
    import os
    os.environ['JAX_ENABLE_X64'] = 'False'
    

Summary

| Scenario | Recommended Configuration |
|----------|---------------------------|
| Quick testing | hc.init(1), or skip the call |
| Daily use | hc.init() (auto-configure) |
| Production runs | hc.init(8) (8-core parallelism) |
| GPU acceleration | hc.init("GPU") |
| Cluster submission | hc.init(N) at the top of the script |

Remember: hc.init() must be called before importing JAX!