JAX 技术详解
============
本章详细介绍 JAX 的核心概念、技术优势,以及 HIcosmo 选择 JAX 作为计算引擎的原因。
如果你对 JAX 不熟悉,本章将帮助你理解为什么 JAX 是现代科学计算的革命性工具。
什么是 JAX?
------------
JAX 是 Google 开发的一个高性能数值计算库,可以理解为 **"可微分、可编译、可并行的 NumPy"**。
它的名字来源于 **J**\ ust-in-time compilation + **A**\ utomatic differentiation + **X**\ LA (加速线性代数)。
**官方定义**:
JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.
**通俗解释**:
- 如果你会用 NumPy,你就会用 JAX(API 几乎相同)
- JAX 可以自动计算任何函数的导数(包括高阶导数)
- JAX 可以自动将代码编译为高效的机器码
- JAX 可以透明地在 CPU、GPU、TPU 上运行
JAX 被视为一种 **可微分编程语言**,它将 Python 作为构建 XLA 计算图的"元编程语言"。
函数式编程:JAX 的核心哲学
--------------------------
JAX 的设计核心在于 **函数式编程(Functional Programming, FP)**,这与传统的面向对象编程(OOP)框架(如 PyTorch)有着本质的区别。
为什么选择函数式编程?
~~~~~~~~~~~~~~~~~~~~~~
JAX 采用函数式编程的主要原因包括:
1. **数学一致性**:自动微分在数学上本质上是函数的变换。函数式编程允许 JAX 将 ``grad``(求导)、``jit``(编译)、``vmap``(向量化)作为高阶函数自由嵌套和组合,例如:``jit(vmap(grad(f)))``
2. **便于编译优化**:XLA 编译器需要一个静态的、无副作用的计算图来执行算子融合和内存优化。面向对象中的状态突变(如修改类成员变量)会破坏这种静态图的构建
3. **确定性随机状态**:传统 OOP 框架常使用全局随机状态,而 JAX 强制要求显式传递随机数生成器(PRNG)密钥,确保代码的可复现性
4. **状态与逻辑分离**:在函数式范式下,模型参数作为函数的输入参数传入,而不是存储在对象内部,使得参数分发和更新更加透明
纯函数 (Pure Functions)
~~~~~~~~~~~~~~~~~~~~~~~
纯函数是 JAX 能够实现高性能计算的基石。
**定义**:纯函数是指 **不改变外部状态** 且 **没有副作用** 的函数。
**特性**:对于相同的输入,纯函数 **必须始终产生相同的输出**。
**为什么重要**:纯函数的可预测性允许 JAX 的 XLA 编译器对代码进行深度优化、实现即时编译(JIT)以及在 GPU/TPU 上轻松实现算子并行和分片。
.. code-block:: python
import jax.numpy as jnp
# ✅ 纯函数:相同输入 → 相同输出,无副作用
def pure_function(x, y):
return jnp.sin(x) + jnp.cos(y)
# ❌ 非纯函数:依赖外部状态
global_state = 0
def impure_function(x):
global global_state
global_state += 1 # 副作用:修改外部状态
return x + global_state
不可变数组 (Immutable Arrays)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
在 JAX 中,数组一旦创建就 **不能被修改**。
**与 NumPy 的区别**:传统 NumPy 允许原地(in-place)更新数组,JAX 数组是不可变的。
**操作方式**:如果需要修改数组中的某个元素,JAX 不会直接覆盖原内存,而是创建一个包含修改后内容的新数组。
.. code-block:: python
import numpy as np
import jax.numpy as jnp
# NumPy:原地修改
np_arr = np.array([1, 2, 3])
np_arr[0] = 99 # ✅ 可以直接修改
# JAX:不可变,需要创建新数组
jax_arr = jnp.array([1, 2, 3])
# jax_arr[0] = 99 # ❌ TypeError!
jax_arr = jax_arr.at[0].set(99) # ✅ 创建新数组
**设计原因**:不可变性解锁了编译优化,使 XLA 能够安全地进行算子融合和内存重用。
梯度计算对比:PyTorch vs JAX
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
以下是 PyTorch(面向对象)与 JAX(函数式)在处理梯度时的核心区别:
**PyTorch (面向对象范式)**:
梯度被视为张量对象的一个属性,通过修改状态来实现。
.. code-block:: python
import torch
def f(x):
return x**2 + 3*x + 5
x = torch.tensor([2.0], requires_grad=True)
loss = f(x)
loss.backward() # 触发 AutoGrad 并在计算图中回溯
# 梯度被"存储"在变量 x 的 .grad 属性中(状态改变)
print(x.grad.item()) # 输出: 7.0
**JAX (函数式范式)**:
求导是一个 **函数变换**,不会改变任何输入或对象的状态。
.. code-block:: python
import jax
def f(x):
return x**2 + 3*x + 5
# jax.grad 接收一个函数并返回另一个函数(函数变换)
df_dx = jax.grad(f)
x_value = 2.0
# 梯度直接作为返回值返回,不修改原变量
derivative_value = df_dx(x_value)
print(derivative_value) # 输出: 7.0
**对比总结**:
.. list-table::
:header-rows: 1
:widths: 25 35 40
* - 特性
- 面向对象 (PyTorch)
- 函数式 (JAX)
* - **状态管理**
- 状态存储在对象内部,允许原地修改
- 纯函数,状态通过参数传递,不可变
* - **梯度处理**
- 通过 ``loss.backward()`` 存储在 ``.grad`` 属性
- 使用 ``jax.grad`` 变换函数,直接返回梯度值
* - **随机数**
- 依赖隐式全局随机状态
- 必须显式传递 PRNG 密钥
* - **底层核心**
- 动态计算图,易于调试
- XLA 编译的静态计算图,追求极致性能
JAX 的四大核心特性
------------------
1. NumPy 兼容 API
~~~~~~~~~~~~~~~~~
JAX 提供与 NumPy 几乎完全兼容的 API:
.. code-block:: python
# NumPy 代码
import numpy as np
x = np.array([1.0, 2.0, 3.0])
y = np.sin(x) + np.cos(x)
z = np.dot(x, y)
# JAX 代码(只需改变 import)
import jax.numpy as jnp
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.sin(x) + jnp.cos(x)
z = jnp.dot(x, y)
**主要区别**:
.. list-table::
:header-rows: 1
:widths: 20 40 40
* - 特性
- NumPy
- JAX
* - **可变性**
- 数组可变 (``x[0] = 5``)
- 数组不可变 (``x = x.at[0].set(5)``)
* - **随机数**
- 全局状态 (``np.random.rand()``)
- 显式 key (``jax.random.uniform(key)``)
* - **执行模式**
- 即时执行
- 可延迟执行(traced)
* - **硬件支持**
- 仅 CPU
- CPU / GPU / TPU
2. 自动微分 (Automatic Differentiation)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
JAX 的自动微分是其最强大的特性之一。它可以自动计算任意复杂函数的精确导数。
**什么是自动微分?**
在科学计算中,我们经常需要计算函数的导数(梯度)。传统方法有三种:
1. **符号微分**:用数学公式手动推导,容易出错,无法处理复杂函数
2. **数值微分**:用有限差分近似,:math:`f'(x) \approx \frac{f(x+h) - f(x)}{h}`,精度低,效率低
3. **自动微分**:通过链式法则自动追踪计算图,精确高效
**JAX 的自动微分示例**:
.. code-block:: python
import jax
import jax.numpy as jnp
# 定义一个复杂函数
def f(x):
return jnp.sin(x) * jnp.exp(-x**2) + jnp.log(1 + x**2)
# 自动获取导数函数
df = jax.grad(f) # 一阶导数
d2f = jax.grad(df) # 二阶导数
d3f = jax.grad(d2f) # 三阶导数
# 计算导数值
x = 1.0
print(f"f(x) = {f(x):.6f}")
print(f"f'(x) = {df(x):.6f}")
print(f"f''(x) = {d2f(x):.6f}")
print(f"f'''(x)= {d3f(x):.6f}")
**向量函数的 Jacobian 和 Hessian**:
.. code-block:: python
# Hessian 矩阵(用于优化和 Fisher 矩阵)
def scalar_func(x):
return jnp.sum(x**2)
hessian = jax.hessian(scalar_func)
**自动微分的优势**:
- **精确**:结果是精确的数学导数,不是数值近似
- **高效**:复杂度与原函数相同,不需要多次函数评估
- **通用**:支持任意复杂的函数组合,包括条件语句和循环
3. JIT 编译与 XLA 编译器
~~~~~~~~~~~~~~~~~~~~~~~~
JIT 编译是 JAX 性能的核心来源。它将 Python 代码编译为优化的机器码。
**XLA 编译器工作原理**
XLA(加速线性代数,Accelerated Linear Algebra)是一个专门针对线性代数运算的编译器,它将 JAX 编写的程序转化为针对特定硬件(CPU、GPU 或 TPU)优化过的可执行内核。
工作流程分为以下几个阶段:
1. **追踪 (Tracing)**:当你对一个 Python 函数应用 ``jax.jit`` 并首次调用它时,JAX 会通过"追踪"机制记录该函数在处理具有特定形状和类型的数组时的操作序列
2. **构建计算图**:追踪得到的信息会被转化为 XLA 计算图(HLO,高级优化器表示)
3. **硬件编译与优化**:XLA 编译器会对该图进行全方位的分析,执行包括死代码消除、算子融合等优化策略,最后将其编译为直接在硬件加速器上运行的二进制机器码
**算子融合 (Operator Fusion)**
算子融合是 XLA 最具杀伤力的优化技术之一。
**原理**:在传统的向量化编程中,每一步简单操作(如 ``A * B + C``)通常会产生大量的中间临时数组。这些中间结果需要写回显存然后再读入,造成了严重的内存带宽浪费。
**实现方式**:XLA 会在计算图中识别出连续的操作,并将其"融合"为一个单一的 GPU 内核。这意味着中间计算结果会直接保留在处理器的快速寄存器或缓存中,而不是频繁地与主存交换数据。
**收益**:显著降低了内存带宽占用并减少了内存分配压力,在处理高维宇宙学数据时尤其能避免 OOM(内存溢出)错误。
**比喻说明**:
你可以把传统的 Python 运算想象成快餐店点餐:你每点一个汉堡,服务员都要跑回厨房做(调用 C 库),做完拿给你,你再点一根薯条,服务员再跑一次。
而 XLA 编译器和 JIT 就像是一个智能主厨:你告诉他你要点一个套餐(整个函数),他会分析菜单,决定同时炸薯条和煎肉饼(算子融合),最后一次性把热气腾腾的完整套餐端给你。
这种"一站式"处理极大地减少了路途上的等待时间(Python 开销和内存读写)。
**JIT 编译示例**:
.. code-block:: python
import jax
import jax.numpy as jnp
import time
# 定义一个计算密集型函数
def slow_function(x):
for _ in range(100):
x = jnp.sin(x) + jnp.cos(x)
return x
# JIT 编译版本
fast_function = jax.jit(slow_function)
x = jnp.ones(10000)
# 第一次调用:包含编译时间
start = time.time()
_ = fast_function(x)
print(f"第一次调用(含编译): {time.time() - start:.4f}s")
# 后续调用:使用缓存的编译结果
start = time.time()
for _ in range(100):
_ = fast_function(x)
print(f"后续100次调用: {time.time() - start:.4f}s")
4. 向量化 (Vectorization with vmap)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
``vmap`` 是 JAX 的向量化变换,可以自动将标量函数转换为批量函数。
**示例**:
.. code-block:: python
import jax
import jax.numpy as jnp
# 定义一个处理单个样本的函数
def single_sample_loglike(theta, data_point):
mu, sigma = theta
return -0.5 * ((data_point - mu) / sigma)**2 - jnp.log(sigma)
# 使用 vmap 自动批量化
batch_loglike = jax.vmap(single_sample_loglike, in_axes=(None, 0))
theta = (0.0, 1.0)
data = jnp.array([0.1, 0.5, -0.3, 0.8, -0.2])
# 一次性计算所有数据点的对数似然
log_likes = batch_loglike(theta, data)
**性能对比** (1000 个数据点):
- Python 循环:~50ms
- NumPy 向量化:~2ms
- JAX vmap + JIT:~0.1ms
PRNG 随机数管理
---------------
JAX 的伪随机数生成器(PRNG)管理机制是其函数式编程哲学的核心体现。与 NumPy 或 PyTorch 等传统框架使用的"全局随机状态"不同,JAX 采用的是一种 **显式、无状态且可分支** 的随机数管理系统。
核心机制:显式密钥与分支
~~~~~~~~~~~~~~~~~~~~~~~~
在 JAX 中,随机性不是由一个随时间变化的全局种子(Seed)控制的,而是通过一个 **随机密钥(PRNGKey)** 来管理的。
- **PRNGKey**:这是一个包含随机状态的数组对象。所有的随机函数(如 ``jax.random.normal``)都需要接收一个 Key 作为输入,并根据该 Key 生成确定的输出
- **密钥分支(Splitting)**:为了生成多个不相关的随机数流,JAX 使用 ``jax.random.split`` 函数将一个主 Key 分解为多个子 Key。这类似于树状结构:从一个根种子开始,不断分叉出独立的随机支流
为什么需要显式传递随机密钥?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
JAX 要求显式传递密钥,根本原因在于其对 **纯函数** 的追求:
1. **消除副作用**:传统的全局随机状态本质上是一个"全局变量"。每次调用随机函数都会改变这个全局状态,这属于"副作用"。为了实现 JIT 和算子融合,函数必须是"纯"的
2. **可复现性**:显式密钥确保了无论计算在何处、以何种顺序执行,只要 Key 相同,结果就完全一致
3. **并行化与矢量化**:在大规模并行计算(如 ``pmap`` 或 ``vmap``)中,全局状态会导致竞争和难以预料的结果。显式密钥使得 PRNG 的生成过程可以轻松地在多个硬件设备上进行并行化
**对比**:
.. list-table::
:header-rows: 1
:widths: 25 35 40
* - 特性
- 传统框架 (PyTorch/NumPy)
- JAX
* - **状态管理**
- 隐式、全局。存在一个后台全局状态
- 显式、局部。密钥作为参数在函数间传递
* - **副作用**
- 有。每次调用随机函数都会改变全局状态
- 无。调用函数不改变输入密钥,只返回结果
* - **并行安全**
- 困难。并行环境下的全局状态同步复杂
- 原生支持。通过 Split 为每个分支提供独立 Key
* - **确定性**
- 依赖于执行顺序
- 仅依赖于传递的 Key,与执行顺序无关
代码示例
~~~~~~~~
.. code-block:: python
import jax
import jax.numpy as jnp
# 1. 创建初始密钥
key = jax.random.PRNGKey(42)
# 2. 如果直接多次使用同一个 key,会得到相同的结果
val1 = jax.random.normal(key)
val2 = jax.random.normal(key)
assert jnp.all(val1 == val2) # 相同!
# 3. 正确做法:使用 split 分支密钥
key, subkey = jax.random.split(key)
random_val = jax.random.normal(subkey)
# 4. 生成多个独立的随机数
key, *subkeys = jax.random.split(key, 5)
random_vals = [jax.random.normal(k) for k in subkeys]
**形象比喻**:
传统的随机生成器像一个 **"自动取款机"**,你每取一次钱,银行后台的余额(全局状态)就会减少,你无法回到取钱前的状态。
而 JAX 的 PRNG 机制像是一张 **"可复制的藏宝图"**,密钥就是地图上的坐标。如果你把地图复制一份(Split)给另一个人,你们按照相同的坐标和步骤走,最终找到的宝藏(随机数)必然是一模一样的。
这使得在大规模分布式搜索(并行计算)中,每个人都能清晰地知道自己负责哪一块区域而不会发生混乱。
JAX 在宇宙学中的应用
--------------------
JAX 在宇宙学计算中的应用正在引发一场范式革命,推动该领域从传统的"数值积分驱动"向 **"全流程可微分推断"(Differentiable Universe)** 转变。
具体应用领域
~~~~~~~~~~~~
JAX 通过其自动微分和 GPU 加速能力,渗透到了宇宙学研究的各个核心环节:
1. **高维贝叶斯推断**:在处理下一代天文观测数据(如 Stage IV 巡天)时,模型参数往往超过 150 个。JAX 使整个似然函数变得完全可导,支持高效的高维空间探索
2. **宇宙学预测仿真器(Emulators)**:使用基于 JAX 的神经网络(如 CosmoPower-JAX)替代传统的玻尔兹曼求解器(如 CAMB/CLASS)。这些仿真器能以亚毫秒级的速度预测物质功率谱等观测效应
3. **可微分宇宙模拟**:JAX 被用于编写完全可微的 N 体模拟和引力透镜模拟(如 JAXtronomy)。这意味着研究人员可以进行"场级推断"(Field-level inference)
4. **Fisher 矩阵计算**:传统方法依赖不稳定的数值差分。JAX 可以通过 ``jax.hessian`` 自动计算精确的二阶导数
如何加速参数推断和 MCMC 采样?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
JAX 的加速并非单一维度的提升,而是通过硬件优化和算法改进的结合实现的:
- **从"随机盲走"到"梯度导航"**:传统的 MCMC(如 Metropolis-Hastings)在参数空间中随机试探,效率随维度增加而急剧下降。JAX 允许使用 **哈密顿蒙特卡罗(HMC)** 或 NUTS 采样器,利用梯度引导采样器沿着高概率区域"滑行"
- **XLA 编译器与算子融合**:通过 ``jax.jit``,XLA 编译器将 Python 代码转化为针对 GPU/TPU 优化的机器码,消除 Python 的解释延迟
- **神经网络加速**:像 CosmoPower-JAX 这样的神经网络仿真器本质上是线性代数操作,这正是 GPU 所擅长的
**比喻理解**:
传统宇宙学采样就像是在浓雾笼罩的山谷中蒙眼摸路(随机撒点),每走一步都要耗费大量体力去试探深度。
而 JAX 就像是为研究者提供了一套 **GPS 导航系统和强力聚光灯**,梯度信息指明了通往山顶的最快路径,而 GPU 硬件加速则像是为研究者换上了高速滑板,让原本需要数年才能走完的路程在几天内即可完成。
宇宙学 JAX 项目生态系统
~~~~~~~~~~~~~~~~~~~~~~~
目前,宇宙学界已经形成了丰富的 JAX 软件生态系统:
.. list-table::
:header-rows: 1
:widths: 20 25 55
* - 项目名称
- 核心用途
- 特点
* - **jax-cosmo**
- 核心宇宙学库
- 提供可微分的背景演化、功率谱和似然函数计算
* - **CosmoPower-JAX**
- 神经网络仿真器
- 高速预测 CMB 和物质功率谱,支持几百个参数的 HMC 采样
* - **candl**
- CMB 似然分析
- 专门用于分析 CMB 功率谱测量(如 SPT、ACT)
* - **JAXtronomy**
- 引力透镜模拟
- lenstronomy 的 JAX 移植版,支持 GPU 加速
* - **microJAX**
- 微引力透镜模拟
- 首个完全可微的微透镜建模框架
* - **DISCO-DJ**
- 可微 N 体模拟
- 实现完全自动微分的宇宙学模拟
* - **PyBird-JAX**
- 有效场论 (EFT)
- 用于快速处理 LSS 数据的 EFT 预测
性能数据与基准测试
------------------
根据研究数据,JAX/XLA 展示了远超传统框架的加速能力:
宇宙学推断加速
~~~~~~~~~~~~~~
.. list-table::
:header-rows: 1
:widths: 40 25 25 10
* - 任务
- 传统方法
- JAX 方法
- 加速比
* - 高维推断(157 参数)
- 12 年(48 核 CPU,嵌套采样)
- 8 天(24 GPU,梯度采样)
- **10^5 倍**
* - 宇宙剪切分析(37 参数)
- 数周
- 数小时
- **~1000 倍**
* - 智利央行经济模型
- 12 小时(工业服务器)
- 几秒(消费级 GPU)
- **~1000 倍**
科学计算加速
~~~~~~~~~~~~
.. list-table::
:header-rows: 1
:widths: 50 20 20 10
* - 操作
- NumPy/SciPy
- JAX JIT
- 加速比
* - 超新星似然函数计算
- ~50 ms
- ~0.05 ms
- **1000 倍**
* - 距离计算 (1000 点)
- 150 ms
- 20 ms
- **7.5 倍**
* - JAXtronomy 射线追踪 (vs CPU)
- 基准
- GPU 加速
- **120-140 倍**
* - 求解万维 PDE
- 基准
- JAX 优化
- **1000 倍 + 30x 内存节省**
为什么 HIcosmo 选择 JAX?
-------------------------
1. MCMC 采样需要高效的梯度计算
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
现代 MCMC 采样器(如 NUTS)依赖于哈密顿蒙特卡洛方法,需要在每一步计算参数的梯度。
**传统数值微分的问题**:
- 对于 10 个参数:需要 20 次函数评估
- 对于 100 个参数:需要 200 次函数评估
- 精度受 epsilon 选择影响
**JAX 自动微分的优势**:
- 无论多少参数,只需要 ~2 次等效函数评估
- 结果是精确的数学导数
- NumPyro NUTS 采样器直接使用 JAX 的自动微分
2. 宇宙学计算需要高精度数值积分
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
宇宙学距离计算涉及积分:
.. math::
d_C(z) = \frac{c}{H_0} \int_0^z \frac{dz'}{E(z')}
**HIcosmo 方法**:预计算高精度积分表,JIT 编译后极快
.. code-block:: python
from hicosmo.models import LCDM
import jax.numpy as jnp
z_grid = jnp.linspace(0.01, 2.0, 1000)
params = {'H0': 70.0, 'Omega_m': 0.3}
# 第一次调用:编译 ~100ms
grid = LCDM.compute_grid_traced(z_grid, params)
# 后续调用:使用缓存 ~0.1ms
grid = LCDM.compute_grid_traced(z_grid, params)
3. Fisher 矩阵需要二阶导数
~~~~~~~~~~~~~~~~~~~~~~~~~~
Fisher 矩阵是似然函数的 Hessian 矩阵的期望值:
.. math::
F_{ij} = -\left\langle \frac{\partial^2 \ln L}{\partial \theta_i \partial \theta_j} \right\rangle
**传统方法**:数值二阶导数需要 :math:`O(n^2)` 次函数评估且不稳定
**JAX 方法**:
.. code-block:: python
import jax
# 自动获取 Hessian 矩阵
hessian = jax.hessian(log_likelihood)
# Fisher 矩阵
fisher_matrix = -hessian(best_fit_params)
JAX vs 其他框架
---------------
与 NumPy 对比
~~~~~~~~~~~~~
.. list-table::
:header-rows: 1
:widths: 25 35 40
* - 特性
- NumPy
- JAX
* - **API**
- 原生 Python
- 与 NumPy 几乎相同
* - **自动微分**
- 不支持
- 完整支持
* - **JIT 编译**
- 不支持
- 支持
* - **GPU 支持**
- 需要 CuPy
- 透明支持
* - **学习曲线**
- 低
- 低(会 NumPy 即可)
与 TensorFlow/PyTorch 对比
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. list-table::
:header-rows: 1
:widths: 20 25 25 30
* - 特性
- TensorFlow
- PyTorch
- JAX
* - **设计目标**
- 深度学习
- 深度学习
- 通用科学计算
* - **API 风格**
- 计算图
- 动态图
- 函数变换
* - **科学计算友好**
- 一般
- 一般
- **优秀**
* - **与 NumPy 兼容**
- 部分
- 部分
- **完全**
* - **函数式编程**
- 有限
- 有限
- **核心设计**
**为什么 HIcosmo 不用 TensorFlow/PyTorch?**
1. **API 复杂度**:TensorFlow/PyTorch 的 API 针对神经网络设计,对于纯数值计算过于复杂
2. **NumPy 兼容性**:JAX 可以直接使用 NumPy 风格的代码,迁移成本低
3. **函数式编程**:宇宙学计算是纯函数,JAX 的函数变换理念完美匹配
4. **编译效率**:XLA 编译器对数值计算有专门优化
JAX 常见陷阱与解决方案
----------------------
1. 数组不可变
~~~~~~~~~~~~~
.. code-block:: python
# ❌ 错误:JAX 数组不可变
x = jnp.array([1, 2, 3])
x[0] = 5 # TypeError!
# ✅ 正确:使用 .at[].set()
x = x.at[0].set(5)
2. 随机数需要显式 key
~~~~~~~~~~~~~~~~~~~~~
.. code-block:: python
# ❌ 错误:没有全局随机状态
x = jax.random.normal() # Error!
# ✅ 正确:显式传递 key
key = jax.random.PRNGKey(42)
x = jax.random.normal(key, shape=(10,))
# 生成新的 key
key, subkey = jax.random.split(key)
y = jax.random.normal(subkey, shape=(10,))
3. JIT 中的动态形状
~~~~~~~~~~~~~~~~~~~
.. code-block:: python
# ❌ 错误:JIT 函数中使用动态形状
@jax.jit
def bad_func(x):
return jnp.zeros(len(x)) # len(x) 在编译时未知
# ✅ 正确:使用 x.shape
@jax.jit
def good_func(x):
return jnp.zeros(x.shape)
4. Python 控制流
~~~~~~~~~~~~~~~~
.. code-block:: python
# ❌ 错误:JIT 中的 Python if
@jax.jit
def bad_func(x, flag):
if flag: # Python 条件
return x + 1
return x
# ✅ 正确:使用 jnp.where
@jax.jit
def good_func(x, flag):
return jnp.where(flag, x + 1, x)
HIcosmo 的 JAX 使用模式
-----------------------
基本使用
~~~~~~~~
.. code-block:: python
# HIcosmo 对 JAX 进行了封装,用户通常不需要直接使用 JAX
from hicosmo import hicosmo
# 这行代码背后使用了 JAX 的:
# - JIT 编译加速距离计算
# - 自动微分提供 NUTS 采样器的梯度
# - vmap 并行化多条 MCMC 链
inf = hicosmo("LCDM", "sn", ["H0", "Omega_m"])
samples = inf.run()
高级使用
~~~~~~~~
.. code-block:: python
import jax
import jax.numpy as jnp
from hicosmo.models import LCDM
# 直接使用 JAX 的自动微分
def my_likelihood(params):
grid = LCDM.compute_grid_traced(z_obs, params)
d_L = grid['d_L']
mu_theory = 5 * jnp.log10(d_L) + 25
chi2 = jnp.sum(((mu_obs - mu_theory) / mu_err)**2)
return -0.5 * chi2
# 自动获取梯度
grad_likelihood = jax.grad(my_likelihood)
# 自动获取 Hessian(用于 Fisher 矩阵)
hessian_likelihood = jax.hessian(my_likelihood)
学习资源
--------
**官方文档**:
- `JAX 官方文档 `_
- `JAX 快速入门 `_
- `JAX 常见陷阱 `_
**推荐教程**:
- `Thinking in JAX `_
- `JAX 101 `_
**社区资源**:
- `Awesome JAX `_:JAX 资源合集
- `NumPyro `_:基于 JAX 的概率编程库(HIcosmo 使用)
**宇宙学 JAX 项目**:
- `jax-cosmo `_:可微分宇宙学
- `CosmoPower-JAX `_:神经网络仿真器
- `JAXtronomy `_:引力透镜模拟
总结
----
JAX 为 HIcosmo 提供了以下核心能力:
1. **自动微分**:NUTS 采样器无需手动梯度,支持任意复杂似然函数
2. **JIT 编译**:10-1000 倍性能提升,XLA 算子融合消除内存瓶颈
3. **向量化**:vmap 实现高效批量计算和多链并行
4. **GPU 支持**:代码无需修改,透明加速
5. **可组合性**:grad/jit/vmap 可以任意组合使用
6. **函数式编程**:纯函数设计确保可复现性和并行安全
通过这种函数式架构,JAX 能够将复杂的物理或数学模型转化为可优化的对象,从而在宇宙学等高维推断任务中实现相较于传统方法 **10^3 到 10^5 倍** 的性能飞跃。
对于用户来说,HIcosmo 已经封装了 JAX 的复杂性。你可以像使用普通 Python 库一样使用 HIcosmo,同时享受 JAX 带来的性能提升。只有在需要自定义似然函数或高级分析时,才需要直接使用 JAX API。
下一步
------
- `核心概念 `_:了解 HIcosmo 的架构设计
- `采样器 `_:了解 MCMC 配置
- `Fisher 预测 `_:了解 Fisher 矩阵分析