张芷铭的个人博客

MPI通信的训练

以下介绍基于MPI通信的多机多卡训练方式,并提供完整的Python代码模板。MPI(Message Passing Interface)是一种跨节点的分布式训练协议,通过消息传递实现设备间通信,适用于大规模集群环境。


一、MPI多机多卡训练原理

  1. 核心概念

    • 进程(Process):每个GPU对应一个独立进程,通过rank标识唯一ID(0为主进程)。
    • 通信域(Communicator):管理进程组(如MPI.COMM_WORLD包含所有进程)。
    • 集合通信(Collective Communication):如AllReduce(全局梯度求和)和Broadcast(参数广播),实现高效数据同步。
  2. 工作流程

    • 数据分片:每个进程加载部分数据(通过rank分配)。
    • 本地计算:各进程独立前向/反向传播。
    • 梯度同步:使用AllReduce汇总所有进程的梯度。
    • 参数更新:主进程广播更新后的参数。

二、环境配置

  1. 软件依赖

    1
    2
    3
    
    pip install mpi4py torch torchvision
    # 安装OpenMPI(Linux示例)
    sudo apt-get install openmpi-bin libopenmpi-dev
    
  2. 主机网络配置

    • 所有节点需免密SSH登录(使用ssh-keygenssh-copy-id)。
    • 节点间需时钟同步(如ntpdate)。

三、Python模板代码

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from mpi4py import MPI
import numpy as np

# 初始化MPI环境
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
world_size = comm.Get_size()
torch.manual_seed(1)  # 确保所有进程初始参数一致

# 1. 数据集分片(模拟MNIST)
class CustomDataset(Dataset):
    def __len__(self):
        return 1000  # 数据总量
    def __getitem__(self, idx):
        idx = (idx + rank) % 1000  # 按rank分配数据
        return torch.randn(3, 28, 28), torch.randint(0, 10, (1,))

dataset = CustomDataset()
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

# 2. 定义模型(简单CNN)
model = nn.Sequential(
    nn.Conv2d(3, 16, 3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16*26*26, 10)
).cuda() if torch.cuda.is_available() else nn.Sequential(...)

optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

# 3. 训练循环
for epoch in range(10):
    for data, target in dataloader:
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()

        # 梯度同步:AllReduce求和
        for param in model.parameters():
            grad_np = param.grad.data.cpu().numpy()
            comm.Allreduce(MPI.IN_PLACE, grad_np, op=MPI.SUM)  # 梯度求和
            param.grad.data = torch.from_numpy(grad_np / world_size).cuda()  # 求平均

        optimizer.step()

    # 主进程保存模型
    if rank == 0:
        torch.save(model.state_dict(), f"model_epoch{epoch}.pth")

四、启动方式

  1. 单机多卡(4卡)

    1
    
    mpirun -n 4 python train.py
    
  2. 多机多卡(2节点,各4卡)

    • hostfile内容(保存为hosts.txt):
      1
      2
      
      node1 slots=4  # 主机名+可用GPU数
      node2 slots=4
      
    • 启动命令
      1
      
      mpirun -n 8 --hostfile hosts.txt python train.py
      

五、优化技巧与常见问题

  1. 性能优化

    • 梯度压缩:减少通信数据量(如精度转FP16)。
    • 异步通信:重叠计算与通信(IAllreduce替代AllReduce)。
  2. 调试建议

    • 日志分离:使用--output-filename log_output将各进程日志存不同文件。
    • 错误处理:任一进程失败则整个任务终止,需检查节点间网络和依赖一致性。

此模板支持数据并行,若需模型并行(如超大模型),可结合torch.distributedDistributedDataParallel或ZeRO优化器。完整示例见https://www.mindspore.cn/tutorials/zh-CN/master/parallel/mpirun.html。

💬 评论