MPI(Message Passing Interface)是跨节点分布式训练协议,通过消息传递实现设备间通信,适用于大规模集群环境。
MPI多机多卡训练原理
核心概念
| 概念 | 说明 |
|---|---|
| 进程(Process) | 每个GPU对应一个独立进程,通过rank标识唯一ID |
| 通信域(Communicator) | 管理进程组,如MPI.COMM_WORLD包含所有进程 |
| 集合通信 | AllReduce(全局梯度求和)、Broadcast(参数广播) |
工作流程
- 数据分片:每个进程加载部分数据
- 本地计算:各进程独立前向/反向传播
- 梯度同步:使用
AllReduce汇总梯度 - 参数更新:主进程广播更新后的参数
环境配置
pip install mpi4py torch torchvision
# 安装OpenMPI
sudo apt-get install openmpi-bin libopenmpi-dev网络配置:
- 所有节点免密SSH登录
- 节点间时钟同步
Python模板代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from mpi4py import MPI
# 初始化MPI环境
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
world_size = comm.Get_size()
# 数据集分片
class CustomDataset(Dataset):
def __len__(self):
return 1000
def __getitem__(self, idx):
idx = (idx + rank) % 1000
return torch.randn(3, 28, 28), torch.randint(0, 10, (1,))
dataset = CustomDataset()
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
# 定义模型
model = nn.Sequential(
nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Flatten(), nn.Linear(16*26*26, 10)
).cuda()
optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
# 训练循环
for epoch in range(10):
for data, target in dataloader:
data, target = data.cuda(), target.cuda()
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
# 梯度同步:AllReduce求和
for param in model.parameters():
grad_np = param.grad.data.cpu().numpy()
comm.Allreduce(MPI.IN_PLACE, grad_np, op=MPI.SUM)
param.grad.data = torch.from_numpy(grad_np / world_size).cuda()
optimizer.step()
if rank == 0:
torch.save(model.state_dict(), f"model_epoch{epoch}.pth")启动方式
单机多卡(4卡):
mpirun -n 4 python train.py多机多卡(2节点各4卡):
hostfile内容:
node1 slots=4
node2 slots=4启动命令:
mpirun -n 8 --hostfile hosts.txt python train.py优化技巧
| 技巧 | 说明 |
|---|---|
| 梯度压缩 | 减少通信数据量(如精度转FP16) |
| 异步通信 | 重叠计算与通信(IAllreduce替代AllReduce) |
| 日志分离 | 使用--output-filename log_output分离各进程日志 |
此模板支持数据并行。模型并行可结合
torch.distributed的DistributedDataParallel或ZeRO优化器。