JAX 是 Google 开发的数值计算库,基于 NumPy API,提供自动微分、JIT 编译和 GPU/TPU 加速。
核心特性
| 特性 | 说明 |
|---|
jax.grad | 自动微分 |
jax.jit | JIT 编译加速 |
jax.vmap | 自动向量化 |
jax.pmap | 多设备并行 |
自动微分
1
2
3
4
5
6
7
8
9
10
11
12
| import jax
import jax.numpy as jnp
def f(x):
return x ** 2
grad_f = jax.grad(f)
print(grad_f(3.0)) # 6.0 (2 * 3)
# 高阶导数
grad2_f = jax.grad(grad_f)
print(grad2_f(3.0)) # 2.0
|
JIT 编译
1
2
3
4
5
6
| @jax.jit
def compute(x):
return jnp.sum(x ** 2)
# 首次调用编译,后续调用极快
result = compute(jnp.ones(1000))
|
向量化 vmap
1
2
3
4
5
6
| def f(x):
return x ** 2
# 批量处理
vmap_f = jax.vmap(f)
print(vmap_f(jnp.array([1.0, 2.0, 3.0]))) # [1.0, 4.0, 9.0]
|
组合使用
1
2
3
4
5
6
7
8
| @jax.jit
def train_step(params, x, y):
def loss_fn(p):
pred = jnp.dot(x, p)
return jnp.mean((pred - y) ** 2)
grads = jax.grad(loss_fn)(params)
return params - 0.1 * grads # 梯度下降
|
设备管理
1
2
3
4
5
6
| # 查看设备
print(jax.devices())
# 放置数据到设备
x = jax.device_put(jnp.ones((3, 3)))
print(x.device())
|
JAX vs NumPy vs PyTorch
| 特性 | JAX | NumPy | PyTorch |
|---|
| 自动微分 | ✅ grad | ❌ | ✅ autograd |
| JIT 编译 | ✅ jit | ❌ | ✅ torch.compile |
| GPU 加速 | ✅ | ❌ | ✅ |
| 函数式风格 | ✅ | ✅ | ❌ |
应用场景
- 深度学习:模型训练、自动微分
- 科学计算:物理模拟、优化问题
- 大规模计算:TPU 集群分布式训练
Comments