训推一致(Train-Inference Consistency)指训练与推理阶段的数据、计算、分布、精度全链路完全等价,是算法落地的核心前提。
核心定义
训推一致:相同输入 → 训练与推理输出数值误差≤1e-4
| 要求 | 说明 |
|---|---|
| 数据一致 | 预处理流程、分布、尺度完全相同 |
| 计算一致 | 网络层计算逻辑、统计量、参数、激活方式相同 |
| 精度一致 | 数值精度、算子实现、运行模式相同 |
训练vs推理天生差异
| 阶段 | 核心目标 | 允许行为 |
|---|---|---|
| 训练 | 学习分布、更新参数 | 随机正则、Batch统计、梯度回传 |
| 推理 | 稳定输出预测 | 固定参数、无随机、无梯度 |
训推一致的本质:抹平所有改变特征分布的差异。
为什么极端重要
| 层次 | 原因 |
|---|---|
| 浅层 | 模型只认识训练时的分布,推理分布变了直接失效 |
| 中层 | 微小误差被深层网络指数级放大,100层叠加偏移达100%+ |
| 深层 | 生成式模型迭代特性,误差累积放大,一步错步步错 |
本质原理:推理时归一化逻辑变化引发二次协变量偏移,让训练时归一化优化失效。
不一致核心场景
数据预处理不一致
训练推理用两套预处理代码,输入分布完全偏移。
准则:必须调用同一个预处理函数。
归一化层行为不一致
| 归一化类型 | 训练 | 推理 | 不一致后果 |
|---|---|---|---|
| BN | 用当前Batch均值/方差 | 用训练累积滑动平均 | 特征分布剧烈偏移 |
| LN | 计算逻辑一致 | 参数初始化、eps需一致 | 相对安全 |
| AdaLN | 条件向量→MLP→动态γ,β,α | 必须复用相同MLP | DiT生成畸变 |
随机正则化开关不一致
| 阶段 | Dropout/DropPath |
|---|---|
| 训练 | 开启 |
| 推理 | 必须关闭 |
模型运行模式不一致
| |
数值精度不一致
训练FP32/BF16,推理FP16/INT8,舍入误差影响AdaLN小数值计算。
条件嵌入不一致(DiT专属)
时间步嵌入→LayerNorm→SiLU→映射为γ/β,推理时少一层Norm或激活函数错误,是文生图上线效果差的第一原因。
模型敏感度对比
| 模型类型 | 敏感度 | 核心痛点 |
|---|---|---|
| 传统CNN | ★★★★☆ | BN层训推不一致 |
| Transformer/LLM | ★★★★★ | 预处理、位置编码、精度、Mask逻辑 |
| DiT/扩散/AdaLN | ★★★★★★ | AdaLN动态参数、时间步嵌入、去噪迭代一致性 |
工程保障方案
| 层次 | 方法 |
|---|---|
| 代码层 | 预处理、网络前向、AdaLN映射共用一套代码,禁止重写 |
| 模式层 | 强制model.eval(),冻结统计量 |
| 归一化层 | BN冻结滑动均值/方差,AdaLN条件处理流程完全对齐 |
| 精度层 | 训练推理精度对齐,量化需校准 |
| 校验层 | 同样本逐层对比输出,误差<1e-4 |
总结
- 训推一致 = 训练与推理数据、计算、分布、精度全链路等价
- 越深的模型、生成模型(DiT/LLM),对一致性要求越苛刻
- AdaLN/DiT场景下,条件嵌入一致性是生成效果的生命线
- 工程唯一解:代码复用 + 强制eval + 逐层校验
张芷铭的个人博客
Comments