FSDP(Fully Sharded Data Parallel)是 PyTorch 官方的”显存救星”:把模型参数、梯度、优化器状态切片分到所有 GPU,每张卡只存 1/N,需要时临时凑齐、用完释放。让 80 GB 单卡也能训百亿/千亿参数模型。
为什么 DDP 不够
DDP 的致命问题:每张 GPU 都存完整的模型 + 梯度 + 优化器状态。
100 亿参数 FP32 模型:
| 组件 | 显存 |
|---|---|
| 参数 | 40 GB |
| 梯度(FP32) | 40 GB |
| Adam 优化器状态(动量+二阶矩) | 80 GB |
| 合计 | 160 GB(单卡装不下) |
冗余存储是大模型训练的核心瓶颈。
FSDP 核心思路
| 步骤 | 行为 |
|---|---|
| 初始化分片 | 把参数/梯度/优化器状态切成 N 份,每张 GPU 只存 1 份 |
| 前向传播 | 算到某模块时通过 all-gather 临时凑齐完整模块,算完立即释放 |
| 反向传播 | 同样凑齐算梯度,再 reduce-scatter 切片,每卡只留自己负责的梯度 |
| 优化器更新 | 每卡独立更新自己的参数分片,无需同步完整模型 |
一句话:用到时凑齐,用完就清空。
三大优势
- 显存巨省:单卡显存降到约 1/N
- 效率不低:通信被拆成多个小批次并与计算重叠,大模型场景吞吐反超 DDP(毕竟 DDP 跑不起来)
- 易用性强:与 PyTorch 无缝衔接,DDP 改 FSDP 通常只需几行代码
何时用
- 模型参数 ≥ 10 亿
- 频繁出现 OOM
- 需要在固定 GPU 资源上扩大模型规模
参数量小、单卡装得下时直接用 DDP,FSDP 的通信代价此时不划算。
与 ZeRO 的关系
FSDP 是 PyTorch 对 DeepSpeed ZeRO 思想的官方实现:
| ZeRO Stage | 等价 FSDP 配置 |
|---|---|
| ZeRO-1 | 仅切优化器状态 |
| ZeRO-2 | 切优化器状态 + 梯度 |
| ZeRO-3 | 全部切片(参数 + 梯度 + 优化器状态)— FSDP 默认行为 |