全量激活重计算(Full Activation Recomputation)是大模型训练中最彻底的"用算力换显存"技术:前向只保留每层输入,反向时整层重跑前向再算梯度。
核心原理
不存中间激活,反向时整层重算,把显存压力转化为计算开销。
为什么需要它
训练大模型时,激活显存是最大瓶颈之一:
- 激活大小 ≈ $O(B \times T^2 \times C)$(B=batch、T=序列长、C=隐层维度)
- 长序列/大batch下,激活可占几十GB甚至上百GB
- 全量重计算可把激活显存从$O(L \times B \times T \times C)$压到$O(B \times T \times C)$
完整流程
前向传播:
- 对每一层,只保存层输入作为检查点
- 该层内所有中间激活(QKV、Score矩阵、FFN中间态)全部丢弃
- 只保留最终输出给下一层
反向传播:
- 从输出层往输入层走
- 对当前层:先用保存的输入完整重跑一次前向
- 再用重算的激活计算参数梯度
- 梯度计算完,再次丢弃该层激活
全量 vs 部分/选择性重计算
| 策略 | 做法 | 显存节省 | 计算overhead | 适用场景 |
|---|---|---|---|---|
| 全量重计算 | 每层只存输入,反向整层重算 | 最大 | 最高(+30%~40%) | 显存极度紧张 |
| 部分重计算 | 仅前N层重算 | 中等 | 中等 | 显存略紧 |
| 选择性重计算 | 只重算高显存模块(如Attention) | 高 | 低(+10%~20%) | 主流选择 |
优缺点
优点:
- 极致省显存,可训练更深/更长序列/更大batch
- 实现简单,框架原生支持
- 兼容性好,与TP、PP、ZeRO等可叠加
缺点:
- 计算开销大,训练速度慢30%~40%
- 无计算复用,所有中间结果都要重算
- 频繁重算增加GPU指令与访存压力
主流框架开启方式
Megatron-LM
| |
PyTorch
| |
NVIDIA NeMo
| |
适用场景
- 显存严重不足:必须训练超大规模模型/超长序列
- 离线训练:对训练时间不敏感
- 叠加优化:全量重计算 + ZeRO + TP = 极限显存压缩
进阶优化
| 技术 | 说明 |
|---|---|
| 重叠重计算 | 重算与通信/IO并行,隐藏延迟 |
| Flash Attention内置重算 | Attention模块自动重算Score |
| 选择性重计算 | 只重算Attention等高显存模块 |
总结
全量激活重计算是大模型训练的显存兜底方案:用30%~40%的速度代价,换取激活显存的极致压缩,让超大模型/长序列训练成为可能。
张芷铭的个人博客
Comments