反向传播利用链式求导法则,把误差逐层传递,高效计算每个参数的梯度。训练计算量约为推理的 3 倍。
核心问题
神经网络是多层嵌套复合函数,目标是找到最优参数使损失最小。
梯度下降:沿梯度反方向更新参数。
$$W_{new} = W_{old} - lr \times \frac{\partial Loss}{\partial W}$$
链式求导法则
单变量:$\frac{dy}{dx} = \frac{dy}{dg} \times \frac{dg}{dx}$
多变量:$\frac{\partial z}{\partial x} = \frac{\partial z}{\partial u} \times \frac{\partial u}{\partial x} + \frac{\partial z}{\partial v} \times \frac{\partial v}{\partial x}$
单层线性层反向传播
前向:$Y = X @ W + b$
反向目标:
- 算参数梯度 $dW$、$db$:用于更新参数
- 算输入梯度 $dX$:传递误差给前一层
| 梯度 | 公式 |
|---|---|
| $dW$ | $X^T @ dY$ |
| $db$ | sum($dY$, axis=0) |
| $dX$ | $dY @ W^T$ |
计算量分析
矩阵乘法 FLOPs = $2 \times a \times b \times c$
| 阶段 | 计算 | FLOPs |
|---|---|---|
| 前向 | $Y = X @ W$ | 2NPM |
| 反向 $dW$ | $X^T @ dY$ | 2NPM |
| 反向 $dX$ | $dY @ W^T$ | 2NPM |
| 训练总计 | 前向 + 反向 | 6NPM |
| 推理 | 仅前向 | 2NPM |
结论:训练计算量 ≈ 3 倍推理。
PyTorch 实现
| |
常见问题
为什么必须算 $dX$? 它是误差传递的唯一桥梁,中间层必须算。
为什么反向比前向慢? 需保存中间结果,显存读写开销大。
梯度消失/爆炸:梯度逐层相乘,指数级缩小或增大。LayerNorm、残差连接可缓解。
张芷铭的个人博客
Comments