扩散模型的训练损失(Loss)设计与其核心目标——学习“去噪过程”密切相关,核心是通过优化损失函数让模型学会预测不同时间步的噪声或还原前一时间步的样本。常见的损失函数可分为基础噪声预测损失、扩展变体损失和特定任务优化损失三大类,具体如下:
一、基础噪声预测损失(核心损失)
扩散模型的核心是通过学习“从加噪样本中预测原始噪声”来优化去噪过程,最基础的损失函数均围绕这一目标设计。
1. 噪声预测的L2损失(DDPM核心损失)
- 原理:在离散时间扩散模型(如DDPM)中,模型的核心任务是预测“当前时间步样本中包含的噪声”。对于原始样本$x_0$和随机噪声$\epsilon \sim \mathcal{N}(0,I)$,任意时间步$t$的加噪样本可表示为$x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}t}\epsilon$。模型$\epsilon\theta(x_t, t)$需预测真实噪声$\epsilon$,损失定义为两者的L2距离。
- 公式:
$$ \mathcal{L}{\text{simple}} = \mathbb{E}{x_0, \epsilon, t} \left[ |\epsilon - \epsilon_\theta(x_t, t)|^2 \right] $$ - 特点:计算简单、训练稳定,是扩散模型最基础的损失函数,后续多数变体均基于此扩展。
2. 分数匹配损失(连续时间模型核心损失)
- 原理:在连续时间扩散模型(基于随机微分方程SDE)中,模型需学习“分数函数”$s_\theta(x_t, t) = -\nabla_{x_t} \log p_t(x_t)$(即数据分布的对数密度梯度),分数函数可用于指导去噪过程。损失函数通过最小化模型预测的分数与真实分数的差异设计。
- 公式:
$$ \mathcal{L}{\text{score}} = \mathbb{E}{x_t, t} \left[ |s(x_t, t) - s_\theta(x_t, t)|^2 \right] $$ 其中$s(x_t, t)$是真实分数(可通过前向过程推导),$s_\theta$是模型预测的分数。 - 特点:适用于连续时间框架,与SDE的数学原理直接对应,常见于Score-Based Generative Models。
二、扩展变体损失(基于基础损失的优化)
为提升生成质量、训练效率或适配特定场景,研究者在基础损失上进行了扩展。
1. 加权噪声预测损失
- 原理:原始L2损失对所有时间步$t$赋予相同权重,但不同时间步的去噪难度可能不同(如高噪声步$t \approx T$和低噪声步$t \approx 0$)。加权损失通过对不同时间步分配不同权重,提升关键步骤的学习效果。
- 公式:
$$ \mathcal{L}{\text{weighted}} = \mathbb{E}{x_0, \epsilon, t} \left[ w(t) \cdot |\epsilon - \epsilon_\theta(x_t, t)|^2 \right] $$ 其中$w(t)$是时间步$t$的权重(如$w(t) = 1/\bar{\alpha}_t$或学习得到的参数)。 - 应用:DDPM论文中曾尝试对高时间步赋予更高权重,部分后续工作(如改进采样效率的模型)通过加权优化提升低噪声步的预测精度。
2. 样本预测损失(而非噪声预测)
- 原理:部分模型不直接预测噪声,而是预测前一时间步的样本$x_{t-1}$或原始样本$x_0$,损失定义为预测样本与真实样本的差异。
- 公式:
- 预测$x_{t-1}$:$\mathcal{L} = \mathbb{E} \left[ |x_{t-1} - x_{t-1,\theta}(x_t, t)|^2 \right]$
- 预测$x_0$:$\mathcal{L} = \mathbb{E} \left[ |x_0 - x_{0,\theta}(x_t, t)|^2 \right]$
- 特点:与噪声预测损失等价(因$x_{t-1}$和$x_0$可通过噪声推导),但更直观反映“样本还原”目标。
3. 感知损失(Perceptual Loss)
- 原理:为提升生成样本的视觉质量(如纹理、语义一致性),引入预训练视觉模型(如VGG、CLIP)的特征作为损失度量,而非直接使用像素级L2损失。
- 公式:
$$ \mathcal{L}_{\text{perceptual}} = \mathbb{E} \left[ | \phi(x_0) - \phi(\hat{x}_0) |^2 \right] $$ 其中$\phi(\cdot)$是预训练模型的特征提取器,$\hat{x}_0$是模型生成的样本。 - 应用:广泛用于图像生成任务(如Stable Diffusion的扩展模型),提升样本的高层语义一致性。
三、特定任务优化损失
针对条件生成、多模态融合等复杂任务,需设计与任务绑定的损失函数。
1. 条件对齐损失(Conditional Alignment Loss)
- 原理:在条件生成任务(如文本到图像)中,需确保生成样本与条件信息(如文本描述)的语义对齐。通过引入跨模态相似度损失,强制生成样本的特征与条件特征匹配。
- 公式:
$$ \mathcal{L}_{\text{align}} = \mathbb{E} \left[ 1 - \text{sim}(\phi(x_0), \psi(y)) \right] $$ 其中$y$是条件信息(如文本),$\phi(\cdot)$和$\psi(\cdot)$分别是图像和文本的特征提取器(如CLIP的视觉/文本编码器),$\text{sim}$是余弦相似度。 - 应用:DALL-E 2、Stable Diffusion等文本到图像模型通过此类损失增强文本与图像的语义一致性。
2. 对抗损失(Adversarial Loss)
- 原理:结合GAN的对抗训练思想,引入判别器区分“真实样本”和“扩散模型生成的样本”,通过对抗损失提升生成样本的真实性。
- 公式:
$$ \mathcal{L}_{\text{adv}} = \mathbb{E} \left[ \log D(x_0) + \log(1 - D(\hat{x}_0)) \right] $$ 其中$D$是判别器,$\hat{x}_0$是扩散模型生成的样本。 - 特点:可提升样本细节质量,但可能增加训练不稳定性(需平衡扩散损失与对抗损失)。
3. 分类器引导损失(Classifier-Guided Loss)
- 原理:在类别条件生成中,若有预训练分类器,可通过分类器的梯度引导生成过程,损失函数包含分类器对生成样本的类别预测误差。
- 公式:
$$ \mathcal{L}_{\text{classifier}} = \mathbb{E} \left[ \text{CE}(y, \text{Classifier}(\hat{x}_0)) \right] $$ 其中$\text{CE}$是交叉熵损失,$y$是样本类别标签。 - 替代方案:Classifier-Free Guidance(无需外部分类器,通过训练时随机丢弃条件信息实现引导),本质是对条件/无条件损失的差值加权,避免额外训练分类器。
总结
扩散模型的损失函数以噪声预测的L2损失(离散时间)和分数匹配损失(连续时间)为核心,其他损失多为其变体或扩展。实际应用中,常根据任务需求组合使用(如基础噪声损失+感知损失+条件对齐损失),以平衡生成质量、训练效率和任务适配性。
RFLoss 是一种专为视频扩散模型设计的噪声预测损失函数,通过动态调整噪声调度、分布式计算优化和多类型损失支持,提升模型在复杂序列数据上的训练效果。以下从数学原理和代码实现两个维度详细解析:
- LPIPS(感知损失):
- 相关博客:LPIPS 图像相似性度量标准、感知损失(Perceptual loss) 使用预训练的VGG网络提取特征,计算特征空间距离: $$ \mathcal{L}{\text{LPIPS}} = | \phi(\epsilon\theta) - \phi(\epsilon) |_2 $$ 适合提升生成结果的视觉质量
💬 评论