张芷铭的个人博客

PyTorch分布式训练通过多设备并行计算加速模型训练,从基础DataParallelFSDP,覆盖不同规模的训练需求。

分布式训练类型

类型说明
数据并行数据批次拆分到不同设备,梯度汇总更新
模型并行模型拆分到不同设备
混合并行结合数据并行和模型并行

发展历程

版本特性
v1.0引入torch.distributed
v1.5推出DistributedDataParallel
v1.11引入FSDP

核心原理

数据并行梯度更新: $$\theta_{t+1} = \theta_t - \eta \cdot \frac{1}{N} \sum_{i=1}^N \nabla_\theta \mathcal{L}(x_i, y_i; \theta_t)$$

通信后端

后端适用场景
NCCLNVIDIA GPU最佳选择
GlooCPU训练
MPI高性能计算环境

适用场景

场景推荐方案
单机多卡DistributedDataParallel
多机多卡DistributedDataParallel
超大模型FSDP或模型并行
弹性训练torch.distributed.elastic

基础实现

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def train(rank, world_size):
    setup(rank, world_size)
    model = model.to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    for batch in dataloader:
        outputs = ddp_model(batch)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    dist.destroy_process_group()

FSDP(完全分片数据并行)

1
2
3
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(model)  # 显著减少显存占用

性能优化技巧

技巧说明
学习率缩放$\eta_{\text{new}} = \eta \times \text{world_size}$
pin_memory=True加速数据传输
梯度累积模拟更大批次
num_workers调优避免I/O瓶颈

常见问题

问题解决方案
死锁确保所有rank通信操作匹配
显存不足使用激活检查点或梯度累积
性能瓶颈使用torch.profiler分析

Comments