JAX 是 Google 开发的数值计算库,基于 NumPy API,提供自动微分、JIT 编译和 GPU/TPU 加速。

核心特性

特性说明
jax.grad自动微分
jax.jitJIT 编译加速
jax.vmap自动向量化
jax.pmap多设备并行

自动微分

import jax
import jax.numpy as jnp
 
def f(x):
    return x ** 2
 
grad_f = jax.grad(f)
print(grad_f(3.0))  # 6.0 (2 * 3)
 
# 高阶导数
grad2_f = jax.grad(grad_f)
print(grad2_f(3.0))  # 2.0

JIT 编译

@jax.jit
def compute(x):
    return jnp.sum(x ** 2)
 
# 首次调用编译,后续调用极快
result = compute(jnp.ones(1000))

向量化 vmap

def f(x):
    return x ** 2
 
# 批量处理
vmap_f = jax.vmap(f)
print(vmap_f(jnp.array([1.0, 2.0, 3.0])))  # [1.0, 4.0, 9.0]

组合使用

@jax.jit
def train_step(params, x, y):
    def loss_fn(p):
        pred = jnp.dot(x, p)
        return jnp.mean((pred - y) ** 2)
 
    grads = jax.grad(loss_fn)(params)
    return params - 0.1 * grads  # 梯度下降

设备管理

# 查看设备
print(jax.devices())
 
# 放置数据到设备
x = jax.device_put(jnp.ones((3, 3)))
print(x.device())

JAX vs NumPy vs PyTorch

特性JAXNumPyPyTorch
自动微分✅ grad✅ autograd
JIT 编译✅ jit✅ torch.compile
GPU 加速
函数式风格

应用场景

  • 深度学习:模型训练、自动微分
  • 科学计算:物理模拟、优化问题
  • 大规模计算:TPU 集群分布式训练