张芷铭的个人博客

混合精度训练结合FP32和FP16浮点数,在保持精度的同时降低50%显存占用、提升2-3倍训练速度。

浮点格式对比

类型符号位指数位分数位数值范围精度
FP321823±3.4×10^38~6e-8
FP161510±65504~5e-4
BF16187±3.4×10^38~1e-2

FP16局限性

问题说明
数值溢出梯度<6e-8下溢,>65504上溢
舍入误差更新量过小时FP16无法精确表示

核心技术

权重备份

  • 维护FP32主权重副本
  • FP16用于前向/反向计算,参数更新在FP32空间完成

损失缩放

  • 前向放大损失:$\mathcal{L}’ = \mathcal{L} \times S$
  • 反向还原梯度:$g’ = g/S$

精度累加

1
2
3
with autocast():
    output = model(inputs)  # FP16计算
loss = criterion(output, targets)  # FP32累加

PyTorch实现

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for inputs, targets in dataloader:
    optimizer.zero_grad()

    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

性能优化技巧

技巧说明
维度对齐Tensor维度设为8的倍数(适配Tensor Core)
算子选择优先matmul,避免einsum
混合策略O1模式:黑白名单自动选择精度

常见问题

现象原因解决方案
Loss出现NaN梯度爆炸减小缩放因子,添加梯度裁剪
训练不收敛权重更新失效检查权重备份机制
显存未降低静态内存占用高torch.cuda.empty_cache()

硬件要求

  • NVIDIA Volta+架构(V100/A100/3090等)
  • V100 FP16: 125 TFLOPS vs FP32: 15.7 TFLOPS
  • A100 FP16: 312 TFLOPS vs TF32: 156 TFLOPS

Comments