Adam(Adaptive Moment Estimation)融合动量机制与自适应学习率,是深度学习领域的默认优化器,工业界实际使用的是 AdamW。
为什么需要 Adam
传统优化器的核心缺陷:
| 优化器 | 问题 |
|---|---|
| 朴素 SGD | 固定学习率,收敛慢、易震荡 |
| 带动量 SGD | 全局固定学习率,调参难 |
| AdaGrad | 历史梯度累积,后期学习率衰减至 0 |
| RMSprop | 缺少动量机制,收敛不稳定 |
Adam 融合了动量机制(Momentum SGD)和自适应学习率(RMSprop),通过偏差修正解决训练初期的矩估计偏置问题。
数学原理
符号定义
| 符号 | 定义 |
|---|---|
| $\theta$ | 模型参数(权重、偏置) |
| $g_t$ | 第 t 步梯度 |
| $\alpha$ | 全局学习率 |
| $\beta_1$ | 一阶矩衰减率(动量系数) |
| $\beta_2$ | 二阶矩衰减率 |
| $m_t$ | 一阶矩估计(梯度 EMA) |
| $v_t$ | 二阶矩估计(梯度平方 EMA) |
| $\hat{m}_t, \hat{v}_t$ | 偏差修正后的矩估计 |
| $\epsilon$ | 防除零常数 |
核心步骤
1. 一阶矩估计(动量)
$$m_t = \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t$$
默认 $\beta_1=0.9$,保留 90% 历史动量,叠加 10% 当前梯度。
2. 二阶矩估计(自适应学习率)
$$v_t = \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot g_t^2$$
默认 $\beta_2=0.999$。梯度大的参数得较小学习率,梯度小的参数得较大学习率。
3. 偏差修正
初始值 $m_0=v_0=0$ 导致训练初期矩估计偏小。修正公式:
$$\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1-\beta_2^t}$$
4. 参数更新
$$\theta_{t+1} = \theta_t - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$
- 分子:动量平滑后的更新方向
- 分母:自适应学习率缩放系数
- $\epsilon$ 通常取 1e-8
伪代码
| |
AdamW:工业界实际使用版本
原始 Adam 的 L2 正则化存在逻辑错误:权重衰减项被二阶矩缩放,无法有效抑制过拟合。
AdamW 改进:解耦权重衰减与梯度更新
$$\theta_{t+1} = \theta_t - \alpha \cdot \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \cdot \theta_t \right)$$
| 特性 | 原始 Adam | AdamW |
|---|---|---|
| 权重衰减 | 耦合进梯度 | 解耦,直接作用于参数 |
| 正则化效果 | 差 | 好 |
| 工业界应用 | 几乎淘汰 | 默认选择 |
超参数配置
| 超参数 | 默认值 | 说明 |
|---|---|---|
| 学习率 $\alpha$ | 1e-3(通用) 1e-4~3e-4(LLM) | 唯一需重点调优的参数 |
| $\beta_1$ | 0.9 | 99% 场景无需修改 |
| $\beta_2$ | 0.999 | Transformer 训练偶尔用 0.95 |
| $\epsilon$ | 1e-8 | FP16 混合精度可用 1e-6 |
| 权重衰减 $\lambda$ | 1e-4(LLM) 1e-2(CV) | AdamW 专用 |
变种与适用场景
| 变种 | 改进 | 适用场景 |
|---|---|---|
| AdamW | 解耦权重衰减 | 通用默认 |
| Adamax | 无穷范数替代二阶矩 | 稀疏梯度、嵌入层 |
| NAdam | Nesterov 动量 | 加速收敛 |
| LAMB | 层自适应学习率 | 大规模分布式训练 |
| AdaFactor | 二阶矩低秩分解 | 显存受限场景 |
| 8-bit Adam | 8-bit 存储矩 | 消费级显卡微调 |
优缺点
优点
- 收敛快,自适应学习率
- 超参数鲁棒,调参简单
- 自带动量,稳定性强
- 适配分布式训练
缺点
- 显存占用高(需存 $m_t$ 和 $v_t$)
- 训练后期可能震荡
- 对异常梯度敏感
工程实践
| |
最佳实践:
- 优先使用 AdamW
- 配合线性预热 + 余弦衰减学习率调度
- 加入梯度裁剪(max_norm=1.0)
- FP16 训练时将矩存为 FP32
- 低显存场景使用 8-bit AdamW(bitsandbytes)
张芷铭的个人博客
Comments