张芷铭的个人博客

Google DeepMind《How To Scale Your Model》系列第三篇精华,大模型分布式训练/推理的张量分片底层理论与方法论。

核心基础概念

为什么要做张量分片

驱动因素说明
显存约束大模型参数/激活无法放入单卡HBM,必须切分分布多设备
性能需求多卡并行提升总算力,降低训练step耗时/推理延迟

核心术语

术语定义
Device Mesh物理设备的逻辑拓扑,如2x2的4卡TPU,Mesh为{'X':2, 'Y':2}
Sharding把张量的逻辑轴映射到Mesh的物理轴,决定如何切分
Global Shape张量的逻辑形状,代码中.shape返回该值
Local Shape单设备上实际存储的张量分片形状
Contracting Dimension矩阵乘法中需要求和的维度,如A[I,J] @ B[J,K]中的J轴
{U_X}未归约标记,表示沿Mesh的X轴仅完成部分和计算

统一分片记号系统

标准格式:A[I_X, J_Y] 表示张量A的第一个逻辑轴I沿Mesh的X轴分片,第二个逻辑轴J沿Y轴分片。

记号含义
A[I_X, J_Y]逻辑轴I沿X轴分片,J沿Y轴分片
A[I_{XY}, J]逻辑轴I沿X+Y轴合并分片,单卡持有`1/(
A[I, J]无Mesh下标,全量复制在所有设备

禁忌:同一Mesh轴不能同时分配给张量的两个逻辑轴,A[I_X, J_X]是非法分片。

JAX代码实现

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
import jax
import jax.numpy as jnp

# 定义Mesh
mesh = jax.make_mesh(axis_shapes=(4, 2), axis_names=('X', 'Y'))

# 分片规则
def P(*args):
  return jax.NamedSharding(mesh, jax.sharding.PartitionSpec(*args))

# 声明分片张量
A = jnp.zeros((8, 2048), dtype=jnp.bfloat16, device=P('X', 'Y'))  # A[I_X, J_Y]
B = jnp.zeros((2048, 8192), dtype=jnp.bfloat16, device=P(None, 'Y'))  # B[I, J_Y]

# 分片矩阵乘法
y = jax.jit(lambda A,B: jnp.einsum('BD,DF->BF', A,B), out_shardings=P('X','Y'))(A,B)

分片矩阵乘法的4大核心场景

矩阵乘法 C = A @ B,其中A[I,J]B[J,K]C[I,K]J为收缩维度。

Case 1:收缩维度都不分片

示例:A[I_X, J] @ B[J, K_Y] → C[I_X, K_Y]

零通信,直接本地块矩阵乘法。每个设备上收缩维度J完整,求和完全在本地完成。

Case 2:仅一个乘数的收缩维度分片

示例:A[I, J_X] @ B[J, K] → C[I, K]

单卡仅持有部分J,无法完成完整求和。

解法

  • 标准:先AllGather补全J维度,再做本地matmul
  • 替代:先做本地matmul得到部分和,再AllReduce(小batch场景成本更低)

Case 3:两个乘数的收缩维度沿同一Mesh轴分片

示例:A[I, J_X] @ B[J_X, K] → C[I, K]{U_X}

可先做本地matmul得到部分和(标记{U_X})。

解法

  • AllReduce_X 得到全量复制的C[I,K](成本2倍AllGather)
  • ReduceScatter 得到分片的C[I, K_X]推荐,成本与AllGather一致)

Case 4:非收缩维度沿同一Mesh轴分片

示例:A[I_X, J] @ B[J, K_X] → 非法分片

同一Mesh轴X被两个非收缩维度占用,信息完全缺失。

唯一解法:先AllGather其中一个矩阵的分片轴。

四大核心集体通信原语

前提

  • 带宽受限(>10MB):耗时仅由数据总量和双向带宽决定,与设备数无关
  • 延迟受限(<45kB for TPU v5e):耗时由单跳延迟和跳数决定

核心原语对比

原语定义示例带宽成本核心场景
AllGather把分片收集成全量,复制到每个设备AllGather_X(A[I_X,J]) → A[I,J]数据量/双向带宽补全分片的收缩维度
ReduceScatter归约部分和后分片到各设备ReduceScatter_{X,K}(C[I,K]{U_X}) → C[I,K_X]同AllGatherFSDP/ZeRO,省显存
AllReduce全局归约,结果全量复制AllReduce_X(C[I,K]{U_X}) → C[I,K]2×AllGather梯度同步
AllToAll分片轴转移,等价分片转置AllToAll_{X,J}(A[I_X,J]) → A[I,J_X]AllGather/4MoE专家路由

数学本质

AllGather 和 ReduceScatter 互为转置:前向用AllGather,反向自动对应ReduceScatter。

工程落地核心结论

分片选型优先级

  1. 优先选择Case 1零通信分片,让收缩维度不分片
  2. 必须通信时,优先选择成本最低的原语:AllToAll > ReduceScatter/AllGather > AllReduce
  3. 大模型训练优先用ReduceScatter替代AllReduce,保留分片状态省显存

通信计算重叠

Collective Matmul实现通信与计算流水重叠:对前序块启动通信的同时计算当前块,隐藏通信延迟。

Transformer FFN分片策略

核心计算:In[B,D] @ W_in[D,F] @ W_out[F,D]

显存约束下,选择张量并行:将F维度沿Mesh轴分片,仅需一次AllGather+一次ReduceScatter。

小batch vs 大batch

场景推荐策略
小batch本地matmul + AllReduce,通信成本更低
大batchAllGather + 本地matmul,算力利用率更高

速记口诀

1
2
3
4
5
分片看收缩,无缩零通信
一缩先Gather,两缩先本地
同轴必冲突,Gather解矛盾
通信选低价,ToAll最划算
ReduceScatter省显存,AllReduce最后选

Comments