张芷铭的个人博客

训推一致(Train-Inference Consistency)指训练与推理阶段的数据、计算、分布、精度全链路完全等价,是算法落地的核心前提。

核心定义

训推一致:相同输入 → 训练与推理输出数值误差≤1e-4

要求说明
数据一致预处理流程、分布、尺度完全相同
计算一致网络层计算逻辑、统计量、参数、激活方式相同
精度一致数值精度、算子实现、运行模式相同

训练vs推理天生差异

阶段核心目标允许行为
训练学习分布、更新参数随机正则、Batch统计、梯度回传
推理稳定输出预测固定参数、无随机、无梯度

训推一致的本质:抹平所有改变特征分布的差异

为什么极端重要

层次原因
浅层模型只认识训练时的分布,推理分布变了直接失效
中层微小误差被深层网络指数级放大,100层叠加偏移达100%+
深层生成式模型迭代特性,误差累积放大,一步错步步错

本质原理:推理时归一化逻辑变化引发二次协变量偏移,让训练时归一化优化失效。

不一致核心场景

数据预处理不一致

训练推理用两套预处理代码,输入分布完全偏移。

准则:必须调用同一个预处理函数。

归一化层行为不一致

归一化类型训练推理不一致后果
BN用当前Batch均值/方差用训练累积滑动平均特征分布剧烈偏移
LN计算逻辑一致参数初始化、eps需一致相对安全
AdaLN条件向量→MLP→动态γ,β,α必须复用相同MLPDiT生成畸变

随机正则化开关不一致

阶段Dropout/DropPath
训练开启
推理必须关闭

模型运行模式不一致

1
2
3
4
# 推理时必写
model.eval()
with torch.no_grad():
    output = model(input)

数值精度不一致

训练FP32/BF16,推理FP16/INT8,舍入误差影响AdaLN小数值计算。

条件嵌入不一致(DiT专属)

时间步嵌入→LayerNorm→SiLU→映射为γ/β,推理时少一层Norm或激活函数错误,是文生图上线效果差的第一原因。

模型敏感度对比

模型类型敏感度核心痛点
传统CNN★★★★☆BN层训推不一致
Transformer/LLM★★★★★预处理、位置编码、精度、Mask逻辑
DiT/扩散/AdaLN★★★★★★AdaLN动态参数、时间步嵌入、去噪迭代一致性

工程保障方案

层次方法
代码层预处理、网络前向、AdaLN映射共用一套代码,禁止重写
模式层强制model.eval(),冻结统计量
归一化层BN冻结滑动均值/方差,AdaLN条件处理流程完全对齐
精度层训练推理精度对齐,量化需校准
校验层同样本逐层对比输出,误差<1e-4

总结

  1. 训推一致 = 训练与推理数据、计算、分布、精度全链路等价
  2. 越深的模型、生成模型(DiT/LLM),对一致性要求越苛刻
  3. AdaLN/DiT场景下,条件嵌入一致性是生成效果的生命线
  4. 工程唯一解:代码复用 + 强制eval + 逐层校验

Comments