并行与加速配置
HIcosmo 基于 JAX 构建,原生支持多核 CPU 并行和 GPU 加速。 本章详细介绍如何配置并行环境,充分发挥硬件性能。
快速开始
一行代码配置并行:
import hicosmo as hc
# 最常用:8 核并行
hc.init(8)
# 自动检测最优配置
hc.init()
# GPU 模式
hc.init("GPU")
重要
必须在导入 JAX 之前调用 hc.init()!
JAX 的设备配置在首次导入时就会固定,之后无法更改。
因此 hc.init() 应该是你脚本的第一行代码。
# ✅ 正确顺序
import hicosmo as hc
hc.init(8)
from hicosmo.models import LCDM # 此时 JAX 被导入
from hicosmo.samplers import MCMC
# ❌ 错误顺序
from hicosmo.models import LCDM # JAX 已导入,设备数固定为 1
import hicosmo as hc
hc.init(8) # ⚠️ 警告:配置被忽略
初始化 API
hc.init() 函数
import hicosmo as hc
hc.init(
num_devices='auto', # 设备数量
device=None, # 设备类型 ('GPU' 或 None)
verbose=True, # 是否显示配置信息
force=True # 是否强制覆盖现有配置
)
参数说明:
参数 |
类型 |
说明 |
|---|---|---|
|
int / str / None |
并行设备数量:
-
int:指定数量(如 8)-
'auto':自动检测(默认,最多 8)-
'GPU':GPU 模式-
None:单设备向量化模式 |
|
str / None |
显式设备类型:
-
None:CPU 模式(默认)-
'GPU' / 'cuda':GPU 模式 |
|
bool |
是否打印配置信息 |
|
bool |
是否覆盖已存在的 XLA_FLAGS |
多核 CPU 并行
基本配置
import hicosmo as hc
# 使用 8 个 CPU 核心
hc.init(8)
# 运行 MCMC 时会自动使用 8 条并行链
from hicosmo.samplers import MCMC
mcmc = MCMC(params, likelihood, chain_name='test')
mcmc.run(num_samples=2000, num_chains=8) # 8 条链并行运行
工作原理
当你调用 hc.init(8) 时,HIcosmo 会:
设置 XLA_FLAGS:配置 JAX 使用 8 个逻辑 CPU 设备
配置线程池:将 CPU 核心分配给各设备
设置 NumPyro:告知 NumPyro 使用 8 个设备
┌─────────────────────────────────────────────────────────┐
│ CPU (8 核) │
├─────────┬─────────┬─────────┬─────────┬─────────────────┤
│ Device 0│ Device 1│ Device 2│ Device 3│ ... Device 7 │
│ (Chain 1)│ (Chain 2)│ (Chain 3)│ (Chain 4)│ ... (Chain 8)│
└─────────┴─────────┴─────────┴─────────┴─────────────────┘
推荐配置
场景 |
推荐配置 |
说明 |
|---|---|---|
个人笔记本(4 核) |
|
充分利用所有核心 |
工作站(8-16 核) |
|
8 条链足够收敛诊断 |
服务器(32+ 核) |
|
更多链不一定更好 |
调试/开发 |
|
单设备更容易调试 |
备注
为什么推荐 8 条链?
足够计算 R-hat 收敛诊断(需要 ≥4 条链)
更多链不会显著提高采样效率
过多链可能导致内存压力
GPU 加速
基本配置
import hicosmo as hc
# 自动检测 GPU
hc.init("GPU")
# 或显式指定
hc.init(device="GPU")
多 GPU 配置
import hicosmo as hc
# 使用 4 个 GPU(如果有)
hc.init(4, device="GPU")
# 自动检测所有可用 GPU
hc.init("GPU") # 自动检测 GPU 数量
GPU 与 CPU 的选择
场景 |
推荐 |
原因 |
|---|---|---|
简单似然(SNe, BAO) |
CPU |
GPU 启动开销大于计算收益 |
复杂似然(CMB 功率谱) |
GPU |
大规模矩阵运算受益于 GPU |
大规模 Fisher 矩阵 |
GPU |
矩阵求逆适合 GPU |
长时间 MCMC(>10万样本) |
GPU |
累积收益明显 |
GPU 环境准备
确保已安装 JAX GPU 版本:
# CUDA 12
pip install jax[cuda12]
# CUDA 11
pip install jax[cuda11_pip]
验证 GPU 可用:
import jax
print(jax.devices())
# [cuda(id=0), cuda(id=1), ...] # GPU 可用
# [CpuDevice(id=0)] # 仅 CPU
集群配置
SLURM 集群
在 SLURM 集群上运行 HIcosmo:
#!/bin/bash
#SBATCH --job-name=hicosmo
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8
#SBATCH --mem=32G
#SBATCH --time=24:00:00
# 激活环境
source activate hicosmo
# 运行脚本
python my_mcmc_script.py
对应的 Python 脚本:
import hicosmo as hc
# 使用 SLURM 分配的 8 个核心
hc.init(8)
# ... 其余代码
PBS/Torque 集群
#!/bin/bash
#PBS -N hicosmo
#PBS -l nodes=1:ppn=8
#PBS -l mem=32gb
#PBS -l walltime=24:00:00
cd $PBS_O_WORKDIR
source activate hicosmo
python my_mcmc_script.py
GPU 集群
#!/bin/bash
#SBATCH --job-name=hicosmo_gpu
#SBATCH --nodes=1
#SBATCH --gres=gpu:4
#SBATCH --mem=64G
#SBATCH --time=24:00:00
source activate hicosmo
python my_mcmc_script.py
import hicosmo as hc
# 使用 4 个 GPU
hc.init(4, device="GPU")
环境变量配置
HIcosmo 自动管理以下环境变量,通常不需要手动设置:
环境变量 |
说明 |
|---|---|
|
JAX/XLA 编译器配置,包括设备数 |
|
启用 64 位精度(默认 True) |
|
每设备线程数 |
手动配置(高级)
如果需要更精细的控制:
# 在 shell 中设置
export XLA_FLAGS="--xla_force_host_platform_device_count=8"
export JAX_ENABLE_X64=True
python my_script.py
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
# 必须在设置环境变量后、导入 JAX 前
import hicosmo as hc
hc.init(8, force=False) # 不覆盖已设置的 XLA_FLAGS
查看当前配置
import hicosmo as hc
hc.init(8)
# 查看配置状态
status = hc.Config.status()
print(status)
# {
# 'initialized': True,
# 'config': {'num_devices': 8, 'device_type': 'cpu', ...},
# 'system_cores': 16,
# 'jax_devices': 8,
# 'jax_device_list': ['CpuDevice(id=0)', 'CpuDevice(id=1)', ...]
# }
# 查看 JAX 设备
import jax
print(f"可用设备: {len(jax.devices())}")
for d in jax.devices():
print(f" {d}")
性能调优建议
匹配链数与设备数
hc.init(8) # num_chains 应该等于或小于设备数 mcmc.run(num_samples=2000, num_chains=8) # ✅ 最优 mcmc.run(num_samples=2000, num_chains=4) # ✅ 可以,但浪费 4 个设备 mcmc.run(num_samples=2000, num_chains=16) # ⚠️ 超过设备数,会排队
预热 JIT 编译
# 第一次调用会触发 JIT 编译(较慢) result = likelihood(H0=70, Omega_m=0.3) # 后续调用使用缓存(很快) result = likelihood(H0=71, Omega_m=0.3)
内存管理
# 大规模采样时监控内存 import jax jax.clear_caches() # 清理 JIT 缓存释放内存
避免频繁初始化
# ❌ 不要在循环中初始化 for i in range(10): hc.init(8) # 每次都会打印警告 ... # ✅ 只初始化一次 hc.init(8) for i in range(10): ...
常见问题
Q: 为什么配置被忽略了?
最常见原因是 JAX 已被导入:
from hicosmo.models import LCDM # ❌ JAX 在这里被导入
import hicosmo as hc
hc.init(8) # ⚠️ 警告:配置被忽略
解决方法:确保 hc.init() 在所有 JAX 相关导入之前。
Q: GPU 未被检测到?
检查 JAX GPU 版本是否安装:
import jax print(jax.devices()) # 应该显示 cuda 设备
检查 CUDA 环境:
nvidia-smi # 应该显示 GPU 信息
重新安装 JAX GPU 版本:
pip uninstall jax jaxlib pip install jax[cuda12]
Q: 如何在 Jupyter Notebook 中使用?
在 notebook 的第一个单元格:
# Cell 1 - 必须是第一个单元格!
import hicosmo as hc
hc.init(8)
# Cell 2 - 然后导入其他模块
from hicosmo.models import LCDM
from hicosmo.samplers import MCMC
警告
如果重启 kernel,需要重新运行 hc.init()。
Q: 内存不足怎么办?
减少并行链数:
hc.init(4) # 而不是 8
减少样本数:
mcmc.run(num_samples=1000) # 而不是 5000
使用 64 位精度时尤其注意:
# 32 位精度可以减少一半内存 os.environ['JAX_ENABLE_X64'] = 'False'
小结
场景 |
推荐配置 |
|---|---|
快速测试 |
|
日常使用 |
|
生产运行 |
|
GPU 加速 |
|
集群提交 |
在脚本开头 |
记住:``hc.init()`` 必须在导入 JAX 之前调用!