张芷铭的个人博客

全量激活重计算(Full Activation Recomputation)是大模型训练中最彻底的"用算力换显存"技术:前向只保留每层输入,反向时整层重跑前向再算梯度。

核心原理

不存中间激活,反向时整层重算,把显存压力转化为计算开销。

为什么需要它

训练大模型时,激活显存是最大瓶颈之一:

  • 激活大小 ≈ $O(B \times T^2 \times C)$(B=batch、T=序列长、C=隐层维度)
  • 长序列/大batch下,激活可占几十GB甚至上百GB
  • 全量重计算可把激活显存从$O(L \times B \times T \times C)$压到$O(B \times T \times C)$

完整流程

前向传播

  1. 对每一层,只保存层输入作为检查点
  2. 该层内所有中间激活(QKV、Score矩阵、FFN中间态)全部丢弃
  3. 只保留最终输出给下一层

反向传播

  1. 从输出层往输入层走
  2. 对当前层:先用保存的输入完整重跑一次前向
  3. 再用重算的激活计算参数梯度
  4. 梯度计算完,再次丢弃该层激活

全量 vs 部分/选择性重计算

策略做法显存节省计算overhead适用场景
全量重计算每层只存输入,反向整层重算最大最高(+30%~40%)显存极度紧张
部分重计算仅前N层重算中等中等显存略紧
选择性重计算只重算高显存模块(如Attention)低(+10%~20%)主流选择

优缺点

优点

  • 极致省显存,可训练更深/更长序列/更大batch
  • 实现简单,框架原生支持
  • 兼容性好,与TP、PP、ZeRO等可叠加

缺点

  • 计算开销大,训练速度慢30%~40%
  • 无计算复用,所有中间结果都要重算
  • 频繁重算增加GPU指令与访存压力

主流框架开启方式

Megatron-LM

1
2
3
--recompute-activations \
--recompute-granularity full \
--recompute-method uniform

PyTorch

1
2
3
4
5
from torch.utils.checkpoint import checkpoint

model = checkpoint(model, input)
# 或每层单独wrap
layer = checkpoint_wrapper(layer)

NVIDIA NeMo

1
2
activations_checkpoint_granularity: full
activations_checkpoint_method: block

适用场景

  • 显存严重不足:必须训练超大规模模型/超长序列
  • 离线训练:对训练时间不敏感
  • 叠加优化:全量重计算 + ZeRO + TP = 极限显存压缩

进阶优化

技术说明
重叠重计算重算与通信/IO并行,隐藏延迟
Flash Attention内置重算Attention模块自动重算Score
选择性重计算只重算Attention等高显存模块

总结

全量激活重计算是大模型训练的显存兜底方案:用30%~40%的速度代价,换取激活显存的极致压缩,让超大模型/长序列训练成为可能。

Comments