PyTorch分布式训练通过多设备并行计算加速模型训练,从基础DataParallelFSDP,覆盖不同规模的训练需求。

分布式训练类型

类型说明
数据并行数据批次拆分到不同设备,梯度汇总更新
模型并行模型拆分到不同设备
混合并行结合数据并行和模型并行

发展历程

版本特性
v1.0引入torch.distributed
v1.5推出DistributedDataParallel
v1.11引入FSDP

核心原理

数据并行梯度更新:

通信后端

后端适用场景
NCCLNVIDIA GPU最佳选择
GlooCPU训练
MPI高性能计算环境

适用场景

场景推荐方案
单机多卡DistributedDataParallel
多机多卡DistributedDataParallel
超大模型FSDP或模型并行
弹性训练torch.distributed.elastic

基础实现

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
 
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
 
def train(rank, world_size):
    setup(rank, world_size)
    model = model.to(rank)
    ddp_model = DDP(model, device_ids=[rank])
 
    for batch in dataloader:
        outputs = ddp_model(batch)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
 
    dist.destroy_process_group()

FSDP(完全分片数据并行)

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
model = FSDP(model)  # 显著减少显存占用

性能优化技巧

技巧说明
学习率缩放
pin_memory=True加速数据传输
梯度累积模拟更大批次
num_workers调优避免I/O瓶颈

常见问题

问题解决方案
死锁确保所有rank通信操作匹配
显存不足使用激活检查点或梯度累积
性能瓶颈使用torch.profiler分析