model.train() 和 model.eval() 控制模型状态,影响 BatchNorm 和 Dropout 层的行为。
核心区别
| 方法 | BatchNorm | Dropout | 用途 |
|---|---|---|---|
| train() | 使用 batch 统计,更新全局均值方差 | 随机丢弃神经元 | 训练阶段 |
| eval() | 使用全局统计,不更新 | 全部神经元参与 | 推理阶段 |
代码示例
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super().__init__()
self.bn = nn.BatchNorm1d(3)
self.dropout = nn.Dropout(0.5)
def forward(self, x):
return self.dropout(self.bn(x))
model = Model()
# 训练模式
model.train()
output_train = model(x)
# 推理模式
model.eval()
output_eval = model(x)行为详解
BatchNorm:
- train():计算当前 batch 均值方差,更新 running_mean/running_var
- eval():使用训练时累计的 running_mean/running_var
Dropout:
- train():以概率 p 随机置零
- eval():所有神经元参与,输出缩放为 倍
注意事项
推理时忘记 model.eval() 会导致:
- BatchNorm 继续更新统计量
- Dropout 随机丢弃,输出不稳定