JAX 数值工具参考
hicosmo.utils.jax_tools 模块提供了一套完整的 JAX 优化数值工具,包括积分、微分、求根和ODE求解器。
所有函数都支持 JIT 编译和自动微分。
快速开始
from hicosmo.utils import (
integrate_simpson, # Simpson积分
gauss_legendre_integrate, # Gauss-Legendre积分
cumulative_trapezoid, # 累积梯形积分
gradient_1d, # 一维梯度
newton_root, # 牛顿法求根
odeint_rk4, # RK4 ODE求解器
)
积分工具
梯形积分
- trapezoid(y, x, axis=-1)
JAX 梯形积分包装器。
- 参数:
y – 函数值数组
x – 自变量数组
axis – 积分轴
- 返回:
积分结果
示例:
import jax.numpy as jnp from hicosmo.utils import trapezoid x = jnp.linspace(0, 1, 100) y = x**2 result = trapezoid(y, x) # ≈ 1/3
Simpson积分
- simpson(y, x)
复合Simpson积分,用于离散采样数据。
- 参数:
y – 函数值数组(至少3个点)
x – 自变量数组
- 返回:
积分结果
- 抛出:
ValueError – 如果点数少于3个
- integrate_simpson(func, a, b, *, num=512)
在区间 [a, b] 上使用Simpson法则积分函数。
- 参数:
func – 被积函数 f(x)
a – 积分下限
b – 积分上限
num – 采样点数(默认512,自动调整为偶数)
- 返回:
积分结果
示例:
from hicosmo.utils import integrate_simpson import jax.numpy as jnp # 计算 ∫₀¹ x² dx = 1/3 result = integrate_simpson(lambda x: x**2, 0, 1) print(f"Result: {result:.6f}") # ≈ 0.333333
对数空间积分
- integrate_logspace(func, k_min, k_max, *, num=512)
在对数空间网格上积分,适用于正定义域的幂律函数。
- 参数:
func – 被积函数 f(k)
k_min – 积分下限(必须 > 0)
k_max – 积分上限
num – 采样点数
- 返回:
积分结果
示例:
from hicosmo.utils import integrate_logspace # 对幂律函数积分 result = integrate_logspace(lambda k: k**(-2), 0.01, 100)
Gauss-Legendre 积分
高精度求积方法,对光滑函数非常有效。
- gauss_legendre_nodes_weights(order)
返回指定阶数的 Gauss-Legendre 节点和权重。
- 参数:
order – 阶数(支持 8, 12, 16)
- 返回:
(nodes, weights) 元组
- gauss_legendre_integrate(integrand, a, b, *, order=12)
在区间 [a, b] 上进行 Gauss-Legendre 积分。
- 参数:
integrand – 被积函数
a – 积分下限
b – 积分上限
order – 积分阶数(8, 12, 或 16)
- 返回:
积分结果
- gauss_legendre_integrate_batch(integrand, z_array, *, z_min=0.0, order=12)
批量 Gauss-Legendre 积分,对多个上限同时计算。
- 参数:
integrand – 被积函数
z_array – 上限数组
z_min – 积分下限
order – 积分阶数
- 返回:
积分结果数组
示例:
from hicosmo.utils import gauss_legendre_integrate_batch import jax.numpy as jnp # 计算多个上限的积分 z_values = jnp.array([0.5, 1.0, 2.0]) results = gauss_legendre_integrate_batch( lambda z: 1.0 / jnp.sqrt(0.3 * (1 + z)**3 + 0.7), z_values, z_min=0.0, order=12 )
累积积分
- cumulative_trapezoid(y, x)
累积梯形积分,返回从第一个点开始的累积积分值。
- 参数:
y – 函数值数组
x – 自变量数组
- 返回:
累积积分数组(第一个元素为0)
示例:
from hicosmo.utils import cumulative_trapezoid import jax.numpy as jnp x = jnp.linspace(0, 1, 100) y = jnp.ones_like(x) # 常数函数 cumulative = cumulative_trapezoid(y, x) # cumulative[-1] ≈ 1.0
- integrate_batch_cumulative(func, z_array, *, z_min=0.0, z_max=None, n_grid=4096)
使用累积网格和插值进行批量积分。适用于大量查询点。
- 参数:
func – 被积函数
z_array – 查询点数组
z_min – 积分下限
z_max – 网格上限(默认为 z_array 的最大值)
n_grid – 网格点数
- 返回:
积分结果数组
性能提示: 当查询点很多时,此方法比
gauss_legendre_integrate_batch更快。
分段积分
- integrate_segmented(func, a, b, *, n_segments=8, order=12)
分段 Gauss-Legendre 积分,适用于大积分区间。
- 参数:
func – 被积函数
a – 积分下限
b – 积分上限
n_segments – 分段数
order – 每段的 Gauss-Legendre 阶数
- 返回:
积分结果
使用场景: 当积分区间很大(如 [0, 1000])时,单次 Gauss-Legendre 积分精度不足,使用分段积分可以提高精度。
自适应Simpson积分
- integrate_adaptive_simpson(func, a, b, *, num=64, max_num=4096, tol=1e-6)
自适应 Simpson 积分,自动调整网格密度。
- 参数:
func – 被积函数
a – 积分下限
b – 积分上限
num – 初始网格点数(必须为偶数)
max_num – 最大网格点数(必须是 num 的整数倍)
tol – 收敛容差
- 返回:
积分结果
微分工具
一维梯度
- gradient_1d(values, coords)
计算一维网格上的导数,支持非均匀网格。
- 参数:
values – 函数值数组
coords – 坐标数组
- 返回:
导数数组
特点: - 端点使用3点Lagrange插值(二阶精度) - 内部点使用二阶中心差分 - 支持非均匀网格
示例:
from hicosmo.utils import gradient_1d import jax.numpy as jnp x = jnp.linspace(0, 2 * jnp.pi, 100) y = jnp.sin(x) dy_dx = gradient_1d(y, x) # ≈ cos(x)
有限差分梯度
- finite_difference_grad(func, x, *, step=1e-5)
使用中心有限差分计算标量函数的梯度。
- 参数:
func – 标量函数 f(x) -> scalar
x – 求值点(数组)
step – 差分步长
- 返回:
梯度数组
注意: 对于可微函数,推荐使用
jax.grad进行自动微分。
自动微分包装器
- grad(func)
返回函数的梯度函数(
jax.grad包装器)。
- jacobian(func)
返回函数的雅可比矩阵函数(
jax.jacobian包装器)。
- hessian(func)
返回函数的 Hessian 矩阵函数(
jax.hessian包装器)。示例:
from hicosmo.utils import grad, hessian import jax.numpy as jnp def f(x): return jnp.sum(x**2) grad_f = grad(f) hess_f = hessian(f) x = jnp.array([1.0, 2.0]) print(grad_f(x)) # [2., 4.]
求根工具
Newton-Raphson 法
- newton_root(func, x0, *, tol=1e-8, max_iter=64, deriv=None)
使用 Newton-Raphson 方法求解 f(x) = 0。
- 参数:
func – 目标函数 f(x)
x0 – 初始猜测值
tol – 收敛容差
max_iter – 最大迭代次数
deriv – 导数函数(可选,默认使用自动微分)
- 返回:
根的近似值
示例:
from hicosmo.utils import newton_root import jax.numpy as jnp # 求解 x² - 2 = 0 root = newton_root(lambda x: x**2 - 2, jnp.array(1.5)) print(f"√2 ≈ {root:.10f}") # 1.4142135624
二分法
- bisection_root(func, a, b, *, tol=1e-8, max_iter=128)
使用二分法求解 f(x) = 0(要求区间端点符号相反)。
- 参数:
func – 目标函数 f(x)
a – 区间左端点
b – 区间右端点
tol – 收敛容差
max_iter – 最大迭代次数
- 返回:
根的近似值(如果端点同号则返回 NaN)
示例:
from hicosmo.utils import bisection_root import jax.numpy as jnp # 求解 x³ - x - 1 = 0 root = bisection_root( lambda x: x**3 - x - 1, jnp.array(1.0), jnp.array(2.0) )
ODE 求解器
RK4 单步
- rk4_step(func, t, y, dt)
执行一步四阶 Runge-Kutta 积分。
- 参数:
func – ODE 右端项 f(t, y)
t – 当前时间
y – 当前状态
dt – 时间步长
- 返回:
下一时刻的状态
RK4 积分器
- odeint_rk4(func, y0, t_grid)
在固定网格上使用 RK4 方法求解 ODE y’ = f(t, y)。
- 参数:
func – ODE 右端项 f(t, y)
y0 – 初始条件
t_grid – 时间网格
- 返回:
解数组(每个时间点一行)
示例:
from hicosmo.utils import odeint_rk4 import jax.numpy as jnp # 求解 y' = -y, y(0) = 1 (解析解: y = e^(-t)) def rhs(t, y): return -y t = jnp.linspace(0, 5, 100) y = odeint_rk4(rhs, jnp.array(1.0), t)
- solve_ivp_rk4(func, y0, t0, t1, *, n_steps=256)
在区间 [t0, t1] 上使用固定步长 RK4 求解初值问题。
- 参数:
func – ODE 右端项 f(t, y)
y0 – 初始条件
t0 – 起始时间
t1 – 终止时间
n_steps – 步数
- 返回:
(t_grid, y_grid) 元组
宇宙学专用工具
声视界计算
基于 Eisenstein & Hu (1998) 拟合公式。
- sound_horizon_drag_eh98(H0, Omega_m, Omega_b, T_cmb)
计算拖拽时刻(drag epoch)的声视界 r_d。
- 参数:
H0 – 哈勃常数 [km/s/Mpc]
Omega_m – 物质密度参数
Omega_b – 重子密度参数
T_cmb – CMB 温度 [K]
- 返回:
声视界 r_d [Mpc]
示例:
from hicosmo.utils.jax_tools import sound_horizon_drag_eh98 r_d = sound_horizon_drag_eh98( H0=67.36, Omega_m=0.3153, Omega_b=0.0493, T_cmb=2.7255 ) print(f"r_d = {r_d:.2f} Mpc") # ≈ 147 Mpc
- sound_horizon_eh98(H0, Omega_m, Omega_b, T_cmb, z)
计算任意红移 z 处的声视界。
- 参数:
H0 – 哈勃常数 [km/s/Mpc]
Omega_m – 物质密度参数
Omega_b – 重子密度参数
T_cmb – CMB 温度 [K]
z – 红移
- 返回:
声视界 r_s(z) [Mpc]
性能对比
各积分方法的适用场景:
方法 |
精度 |
适用场景 |
|---|---|---|
|
O(h²) |
快速粗略估计 |
|
O(h⁴) |
光滑函数的标准选择 |
|
高 |
小区间高精度积分 |
|
O(h²) |
大量查询点的批量积分 |
|
高 |
大区间高精度积分 |
JIT 编译注意事项
所有函数都兼容 JAX JIT 编译:
from jax import jit
from hicosmo.utils import integrate_simpson
@jit
def my_function(params):
result = integrate_simpson(
lambda x: params[0] * x**2 + params[1],
0, 1
)
return result
注意事项:
积分区间 [a, b] 可以是动态值
网格点数 (num, n_grid) 必须是静态值
使用
static_argnums处理需要静态化的参数
完整示例
计算共动距离
import jax.numpy as jnp
from hicosmo.utils import integrate_batch_cumulative
# 宇宙学参数
H0 = 70.0 # km/s/Mpc
Omega_m = 0.3
c = 299792.458 # km/s
def E_z(z):
"""Hubble 参数 E(z) = H(z)/H0"""
return jnp.sqrt(Omega_m * (1 + z)**3 + (1 - Omega_m))
# 计算多个红移的共动距离
z_array = jnp.array([0.1, 0.5, 1.0, 2.0])
integrals = integrate_batch_cumulative(
lambda z: 1.0 / E_z(z),
z_array,
n_grid=4096
)
d_c = (c / H0) * integrals
for z, d in zip(z_array, d_c):
print(f"z = {z:.1f}: d_c = {d:.1f} Mpc")
参考文献
Eisenstein, D. J., & Hu, W. (1998). Baryonic Features in the Matter Transfer Function. ApJ, 496, 605.