FSDP(Fully Sharded Data Parallel)是 PyTorch 官方的”显存救星”:把模型参数、梯度、优化器状态切片分到所有 GPU,每张卡只存 1/N,需要时临时凑齐、用完释放。让 80 GB 单卡也能训百亿/千亿参数模型。

为什么 DDP 不够

DDP 的致命问题:每张 GPU 都存完整的模型 + 梯度 + 优化器状态

100 亿参数 FP32 模型:

组件显存
参数40 GB
梯度(FP32)40 GB
Adam 优化器状态(动量+二阶矩)80 GB
合计160 GB(单卡装不下)

冗余存储是大模型训练的核心瓶颈。

FSDP 核心思路

步骤行为
初始化分片把参数/梯度/优化器状态切成 N 份,每张 GPU 只存 1 份
前向传播算到某模块时通过 all-gather 临时凑齐完整模块,算完立即释放
反向传播同样凑齐算梯度,再 reduce-scatter 切片,每卡只留自己负责的梯度
优化器更新每卡独立更新自己的参数分片,无需同步完整模型

一句话:用到时凑齐,用完就清空

三大优势

  • 显存巨省:单卡显存降到约 1/N
  • 效率不低:通信被拆成多个小批次并与计算重叠,大模型场景吞吐反超 DDP(毕竟 DDP 跑不起来)
  • 易用性强:与 PyTorch 无缝衔接,DDP 改 FSDP 通常只需几行代码

何时用

  • 模型参数 ≥ 10 亿
  • 频繁出现 OOM
  • 需要在固定 GPU 资源上扩大模型规模

参数量小、单卡装得下时直接用 DDP,FSDP 的通信代价此时不划算。

与 ZeRO 的关系

FSDP 是 PyTorch 对 DeepSpeed ZeRO 思想的官方实现:

ZeRO Stage等价 FSDP 配置
ZeRO-1仅切优化器状态
ZeRO-2切优化器状态 + 梯度
ZeRO-3全部切片(参数 + 梯度 + 优化器状态)— FSDP 默认行为