Google DeepMind《How To Scale Your Model》系列第三篇精华,大模型分布式训练/推理的张量分片底层理论与方法论。
核心基础概念
为什么要做张量分片
| 驱动因素 | 说明 |
|---|---|
| 显存约束 | 大模型参数/激活无法放入单卡HBM,必须切分分布多设备 |
| 性能需求 | 多卡并行提升总算力,降低训练step耗时/推理延迟 |
核心术语
| 术语 | 定义 |
|---|---|
| Device Mesh | 物理设备的逻辑拓扑,如2x2的4卡TPU,Mesh为{'X':2, 'Y':2} |
| Sharding | 把张量的逻辑轴映射到Mesh的物理轴,决定如何切分 |
| Global Shape | 张量的逻辑形状,代码中.shape返回该值 |
| Local Shape | 单设备上实际存储的张量分片形状 |
| Contracting Dimension | 矩阵乘法中需要求和的维度,如A[I,J] @ B[J,K]中的J轴 |
| {U_X} | 未归约标记,表示沿Mesh的X轴仅完成部分和计算 |
统一分片记号系统
标准格式:A[I_X, J_Y] 表示张量A的第一个逻辑轴I沿Mesh的X轴分片,第二个逻辑轴J沿Y轴分片。
| 记号 | 含义 |
|---|---|
A[I_X, J_Y] | 逻辑轴I沿X轴分片,J沿Y轴分片 |
A[I_{XY}, J] | 逻辑轴I沿X+Y轴合并分片,单卡持有`1/( |
A[I, J] | 无Mesh下标,全量复制在所有设备 |
禁忌:同一Mesh轴不能同时分配给张量的两个逻辑轴,A[I_X, J_X]是非法分片。
JAX代码实现
| |
分片矩阵乘法的4大核心场景
矩阵乘法 C = A @ B,其中A[I,J]、B[J,K]、C[I,K],J为收缩维度。
Case 1:收缩维度都不分片
示例:A[I_X, J] @ B[J, K_Y] → C[I_X, K_Y]
零通信,直接本地块矩阵乘法。每个设备上收缩维度J完整,求和完全在本地完成。
Case 2:仅一个乘数的收缩维度分片
示例:A[I, J_X] @ B[J, K] → C[I, K]
单卡仅持有部分J,无法完成完整求和。
解法:
- 标准:先
AllGather补全J维度,再做本地matmul - 替代:先做本地matmul得到部分和,再
AllReduce(小batch场景成本更低)
Case 3:两个乘数的收缩维度沿同一Mesh轴分片
示例:A[I, J_X] @ B[J_X, K] → C[I, K]{U_X}
可先做本地matmul得到部分和(标记{U_X})。
解法:
AllReduce_X得到全量复制的C[I,K](成本2倍AllGather)ReduceScatter得到分片的C[I, K_X](推荐,成本与AllGather一致)
Case 4:非收缩维度沿同一Mesh轴分片
示例:A[I_X, J] @ B[J, K_X] → 非法分片
同一Mesh轴X被两个非收缩维度占用,信息完全缺失。
唯一解法:先AllGather其中一个矩阵的分片轴。
四大核心集体通信原语
前提
- 带宽受限(>10MB):耗时仅由数据总量和双向带宽决定,与设备数无关
- 延迟受限(<45kB for TPU v5e):耗时由单跳延迟和跳数决定
核心原语对比
| 原语 | 定义 | 示例 | 带宽成本 | 核心场景 |
|---|---|---|---|---|
| AllGather | 把分片收集成全量,复制到每个设备 | AllGather_X(A[I_X,J]) → A[I,J] | 数据量/双向带宽 | 补全分片的收缩维度 |
| ReduceScatter | 归约部分和后分片到各设备 | ReduceScatter_{X,K}(C[I,K]{U_X}) → C[I,K_X] | 同AllGather | FSDP/ZeRO,省显存 |
| AllReduce | 全局归约,结果全量复制 | AllReduce_X(C[I,K]{U_X}) → C[I,K] | 2×AllGather | 梯度同步 |
| AllToAll | 分片轴转移,等价分片转置 | AllToAll_{X,J}(A[I_X,J]) → A[I,J_X] | AllGather/4 | MoE专家路由 |
数学本质
AllGather 和 ReduceScatter 互为转置:前向用AllGather,反向自动对应ReduceScatter。
工程落地核心结论
分片选型优先级
- 优先选择Case 1零通信分片,让收缩维度不分片
- 必须通信时,优先选择成本最低的原语:
AllToAll > ReduceScatter/AllGather > AllReduce - 大模型训练优先用
ReduceScatter替代AllReduce,保留分片状态省显存
通信计算重叠
用Collective Matmul实现通信与计算流水重叠:对前序块启动通信的同时计算当前块,隐藏通信延迟。
Transformer FFN分片策略
核心计算:In[B,D] @ W_in[D,F] @ W_out[F,D]
显存约束下,选择张量并行:将F维度沿Mesh轴分片,仅需一次AllGather+一次ReduceScatter。
小batch vs 大batch
| 场景 | 推荐策略 |
|---|---|
| 小batch | 本地matmul + AllReduce,通信成本更低 |
| 大batch | AllGather + 本地matmul,算力利用率更高 |
速记口诀
| |
张芷铭的个人博客
Comments