PyTorch分布式训练通过多设备并行计算加速模型训练,从基础DataParallel到FSDP,覆盖不同规模的训练需求。
分布式训练类型
| 类型 | 说明 |
|---|
| 数据并行 | 数据批次拆分到不同设备,梯度汇总更新 |
| 模型并行 | 模型拆分到不同设备 |
| 混合并行 | 结合数据并行和模型并行 |
发展历程
| 版本 | 特性 |
|---|
| v1.0 | 引入torch.distributed包 |
| v1.5 | 推出DistributedDataParallel |
| v1.11 | 引入FSDP |
核心原理
数据并行梯度更新:
$$\theta_{t+1} = \theta_t - \eta \cdot \frac{1}{N} \sum_{i=1}^N \nabla_\theta \mathcal{L}(x_i, y_i; \theta_t)$$
通信后端
| 后端 | 适用场景 |
|---|
| NCCL | NVIDIA GPU最佳选择 |
| Gloo | CPU训练 |
| MPI | 高性能计算环境 |
适用场景
| 场景 | 推荐方案 |
|---|
| 单机多卡 | DistributedDataParallel |
| 多机多卡 | DistributedDataParallel |
| 超大模型 | FSDP或模型并行 |
| 弹性训练 | torch.distributed.elastic |
基础实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
| import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group("nccl", rank=rank, world_size=world_size)
def train(rank, world_size):
setup(rank, world_size)
model = model.to(rank)
ddp_model = DDP(model, device_ids=[rank])
for batch in dataloader:
outputs = ddp_model(batch)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
dist.destroy_process_group()
|
FSDP(完全分片数据并行)
1
2
3
| from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
model = FSDP(model) # 显著减少显存占用
|
性能优化技巧
| 技巧 | 说明 |
|---|
| 学习率缩放 | $\eta_{\text{new}} = \eta \times \text{world_size}$ |
pin_memory=True | 加速数据传输 |
| 梯度累积 | 模拟更大批次 |
num_workers调优 | 避免I/O瓶颈 |
常见问题
| 问题 | 解决方案 |
|---|
| 死锁 | 确保所有rank通信操作匹配 |
| 显存不足 | 使用激活检查点或梯度累积 |
| 性能瓶颈 | 使用torch.profiler分析 |
Comments