张芷铭的个人博客

JAX 是 Google 开发的数值计算库,基于 NumPy API,提供自动微分、JIT 编译和 GPU/TPU 加速。

核心特性

特性说明
jax.grad自动微分
jax.jitJIT 编译加速
jax.vmap自动向量化
jax.pmap多设备并行

自动微分

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
import jax
import jax.numpy as jnp

def f(x):
    return x ** 2

grad_f = jax.grad(f)
print(grad_f(3.0))  # 6.0 (2 * 3)

# 高阶导数
grad2_f = jax.grad(grad_f)
print(grad2_f(3.0))  # 2.0

JIT 编译

1
2
3
4
5
6
@jax.jit
def compute(x):
    return jnp.sum(x ** 2)

# 首次调用编译,后续调用极快
result = compute(jnp.ones(1000))

向量化 vmap

1
2
3
4
5
6
def f(x):
    return x ** 2

# 批量处理
vmap_f = jax.vmap(f)
print(vmap_f(jnp.array([1.0, 2.0, 3.0])))  # [1.0, 4.0, 9.0]

组合使用

1
2
3
4
5
6
7
8
@jax.jit
def train_step(params, x, y):
    def loss_fn(p):
        pred = jnp.dot(x, p)
        return jnp.mean((pred - y) ** 2)

    grads = jax.grad(loss_fn)(params)
    return params - 0.1 * grads  # 梯度下降

设备管理

1
2
3
4
5
6
# 查看设备
print(jax.devices())

# 放置数据到设备
x = jax.device_put(jnp.ones((3, 3)))
print(x.device())

JAX vs NumPy vs PyTorch

特性JAXNumPyPyTorch
自动微分✅ grad✅ autograd
JIT 编译✅ jit✅ torch.compile
GPU 加速
函数式风格

应用场景

  • 深度学习:模型训练、自动微分
  • 科学计算:物理模拟、优化问题
  • 大规模计算:TPU 集群分布式训练

Comments