张芷铭的个人博客

MPI(Message Passing Interface)是跨节点分布式训练协议,通过消息传递实现设备间通信,适用于大规模集群环境。

MPI多机多卡训练原理

核心概念

概念说明
进程(Process)每个GPU对应一个独立进程,通过rank标识唯一ID
通信域(Communicator)管理进程组,如MPI.COMM_WORLD包含所有进程
集合通信AllReduce(全局梯度求和)、Broadcast(参数广播)

工作流程

  1. 数据分片:每个进程加载部分数据
  2. 本地计算:各进程独立前向/反向传播
  3. 梯度同步:使用AllReduce汇总梯度
  4. 参数更新:主进程广播更新后的参数

环境配置

1
2
3
pip install mpi4py torch torchvision
# 安装OpenMPI
sudo apt-get install openmpi-bin libopenmpi-dev

网络配置

  • 所有节点免密SSH登录
  • 节点间时钟同步

Python模板代码

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from mpi4py import MPI

# 初始化MPI环境
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
world_size = comm.Get_size()

# 数据集分片
class CustomDataset(Dataset):
    def __len__(self):
        return 1000
    def __getitem__(self, idx):
        idx = (idx + rank) % 1000
        return torch.randn(3, 28, 28), torch.randint(0, 10, (1,))

dataset = CustomDataset()
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

# 定义模型
model = nn.Sequential(
    nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Flatten(), nn.Linear(16*26*26, 10)
).cuda()

optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

# 训练循环
for epoch in range(10):
    for data, target in dataloader:
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()

        # 梯度同步:AllReduce求和
        for param in model.parameters():
            grad_np = param.grad.data.cpu().numpy()
            comm.Allreduce(MPI.IN_PLACE, grad_np, op=MPI.SUM)
            param.grad.data = torch.from_numpy(grad_np / world_size).cuda()

        optimizer.step()

    if rank == 0:
        torch.save(model.state_dict(), f"model_epoch{epoch}.pth")

启动方式

单机多卡(4卡)

1
mpirun -n 4 python train.py

多机多卡(2节点各4卡)

hostfile内容:

1
2
node1 slots=4
node2 slots=4

启动命令:

1
mpirun -n 8 --hostfile hosts.txt python train.py

优化技巧

技巧说明
梯度压缩减少通信数据量(如精度转FP16)
异步通信重叠计算与通信(IAllreduce替代AllReduce
日志分离使用--output-filename log_output分离各进程日志

此模板支持数据并行。模型并行可结合torch.distributedDistributedDataParallel或ZeRO优化器。

Comments