张芷铭的个人博客

model.train() 和 model.eval() 控制模型状态,影响 BatchNorm 和 Dropout 层的行为。

核心区别

方法BatchNormDropout用途
train()使用 batch 统计,更新全局均值方差随机丢弃神经元训练阶段
eval()使用全局统计,不更新全部神经元参与推理阶段

代码示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm1d(3)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        return self.dropout(self.bn(x))

model = Model()

# 训练模式
model.train()
output_train = model(x)

# 推理模式
model.eval()
output_eval = model(x)

行为详解

BatchNorm

  • train():计算当前 batch 均值方差,更新 running_mean/running_var
  • eval():使用训练时累计的 running_mean/running_var

Dropout

  • train():以概率 p 随机置零
  • eval():所有神经元参与,输出缩放为 $(1-p)$ 倍

注意事项

推理时忘记 model.eval() 会导致:

  • BatchNorm 继续更新统计量
  • Dropout 随机丢弃,输出不稳定

Comments