扩散模型损失函数以噪声预测 L2 损失(离散时间)和分数匹配损失(连续时间)为核心,其他损失多为其变体或扩展。
基础噪声预测损失
噪声预测 L2 损失(DDPM 核心)
$$\mathcal{L}{\text{simple}} = \mathbb{E}{x_0, \epsilon, t} \left[ |\epsilon - \epsilon_\theta(x_t, t)|^2 \right]$$
模型预测当前时间步样本中包含的噪声,计算简单、训练稳定。
分数匹配损失(连续时间模型)
$$\mathcal{L}{\text{score}} = \mathbb{E}{x_t, t} \left[ |s(x_t, t) - s_\theta(x_t, t)|^2 \right]$$
学习分数函数 $s_\theta(x_t, t) = -\nabla_{x_t} \log p_t(x_t)$,适用于 SDE 框架。
扩展变体损失
加权噪声预测损失
$$\mathcal{L}{\text{weighted}} = \mathbb{E}{x_0, \epsilon, t} \left[ w(t) \cdot |\epsilon - \epsilon_\theta(x_t, t)|^2 \right]$$
对不同时间步分配不同权重,优化关键步骤学习效果。
样本预测损失
| 预测目标 | 公式 |
|---|---|
| $x_{t-1}$ | $\mathcal{L} = \mathbb{E} \left[ |x_{t-1} - x_{t-1,\theta}(x_t, t)|^2 \right]$ |
| $x_0$ | $\mathcal{L} = \mathbb{E} \left[ |x_0 - x_{0,\theta}(x_t, t)|^2 \right]$ |
感知损失(LPIPS)
$$\mathcal{L}{\text{LPIPS}} = | \phi(\epsilon\theta) - \phi(\epsilon) |_2$$
使用预训练 VGG 网络提取特征,计算特征空间距离,提升视觉质量。
特定任务优化损失
条件对齐损失
$$\mathcal{L}_{\text{align}} = \mathbb{E} \left[ 1 - \text{sim}(\phi(x_0), \psi(y)) \right]$$
确保生成样本与条件信息的语义对齐。
对抗损失
$$\mathcal{L}_{\text{adv}} = \mathbb{E} \left[ \log D(x_0) + \log(1 - D(\hat{x}_0)) \right]$$
结合 GAN 对抗训练,提升样本细节质量。
分类器引导损失
$$\mathcal{L}_{\text{classifier}} = \mathbb{E} \left[ \text{CE}(y, \text{Classifier}(\hat{x}_0)) \right]$$
通过分类器梯度引导生成过程。
视频扩散模型专用损失
RFLoss
- 动态噪声调度:根据帧间时序依赖关系调整噪声权重
- 分布式计算优化:多 GPU 并行处理长视频序列
- 多类型损失兼容:同时支持噪声预测损失、感知损失
总结
| 损失类型 | 核心应用 |
|---|---|
| L2 噪声预测 | 离散时间扩散模型基础 |
| 分数匹配 | 连续时间 SDE 框架 |
| 感知损失 | 提升视觉质量 |
| 条件对齐 | 多模态生成 |
| 对抗损失 | 细节增强 |
实际应用中常组合使用多种损失,平衡生成质量、训练效率和任务适配性。
张芷铭的个人博客
Comments