张芷铭的个人博客

Diffusion Model 基于马尔可夫链实现逐步加噪与去噪,已成为生成式 AI 的核心引擎。

定义与发展历程

Diffusion Model 是一类基于马尔可夫链的生成式模型,通过逐步添加噪声破坏数据分布,再学习逆向去噪过程重建数据。受非平衡统计物理学启发,该模型通过前向扩散系统性破坏数据结构,再通过反向扩散恢复结构。

发展里程碑

阶段模型贡献
奠基DDPM建立加噪-去噪基本范式
效率突破DDIM确定性采样加速,生成速度提升 10×
跨模态演进Stable Diffusion潜在空间操作,显著降低计算开销
工业应用GLIDE文本引导图像生成

核心原理

前向扩散过程

将原始数据 $X_0$ 逐步转化为高斯噪声:

$$X_t = \sqrt{\alpha_t}X_{t-1} + \sqrt{1-\alpha_t}Z_t, \quad Z_t \sim \mathcal{N}(0,I)$$

闭式解:

$$q(X_t|X_0) = \mathcal{N}(X_t; \sqrt{\bar{\alpha}_t}X_0, (1-\bar{\alpha}_t)I)$$

其中 $\bar{\alpha}t = \prod{i=1}^t \alpha_i$,$\alpha_t = 1 - \beta_t$,$\beta_t$ 为噪声调度系数。

逆向去噪过程

学习映射 $p_\theta(X_{t-1}|X_t)$ 以重建数据。通过变分推断优化变分下界(ELBO):

$$\mathcal{L}{\text{VLB}} = \mathbb{E}q \left[ \log \frac{q(X{1:T}|X_0)}{p\theta(X_{0:T})} \right]$$

分解为逐时间步的 KL 散度项:

$$\mathcal{L}t = D{\text{KL}}\left( q(X_t|X_{t+1},X_0) \parallel p_\theta(X_t|X_{t+1}) \right)$$

训练目标简化

通过参数重整化,目标简化为噪声预测任务

$$\mathcal{L}{\text{simple}} = \mathbb{E}{t,X_0,\epsilon} \left[ | \epsilon - \epsilon_\theta(X_t,t) |^2 \right]$$

U-Net 模型 $\epsilon_\theta$ 学习预测添加的噪声。

关键技术

噪声调度策略

策略公式优势局限
线性调度$\beta_t = \beta_{\min} + (\beta_{\max} - \beta_{\min})\frac{t}{T}$实现简单噪声增减不均衡
余弦调度$\bar{\alpha}_t = \frac{\cos(t/T \cdot \pi/2)}{\cos(\pi/2)}$平滑过渡,保留细节计算复杂度较高

U-Net 架构改进

1
2
3
4
5
6
7
class UNet(nn.Module):
    def __init__(self, input_channels=3, output_channels=3):
        super().__init__()
        self.down1 = DownsampleBlock(64)
        self.attn1 = SelfAttentionBlock(128)
        self.up1 = UpsampleBlock(256)
        self.conv_out = nn.Conv2d(64, 3, kernel_size=1)

核心改进:

  • 残差块替换为自注意力块
  • 时间步嵌入融入各层
  • 跳跃连接保留空间信息

应用场景

领域案例技术亮点
图像生成Stable Diffusion潜在空间扩散,512×512 生成仅需 2 秒
图像编辑GLIDE文本引导局部编辑
视频生成Make-A-Video时间维度扩散,帧间一致性保持
科学计算AlphaFold3蛋白质结构扩散生成

最新进展

EDM2 架构

EDM2 在 ImageNet-512 上 FID=1.81,模型缩小 5 倍仍保持 SOTA:

  • 激活值保持:强制每层输入/输出激活值范数不变
  • 组归一化移除:简化网络结构
  • 偏置项消除:提升训练稳定性

一致性模型

  • 单步生成:扩散轨迹映射为 ODE,蒸馏实现一步采样
  • 零样本编辑:预训练模型实现图像修复、插值

多模态融合

  • CLIP 引导:文本编码器与扩散模型联合训练
  • 3D 扩散:NeRF + Diffusion 实现三维场景生成

采样代码示例

1
2
3
4
5
6
7
8
9
def ddpm_sampling(model, noise, T, alpha_bars):
    x = noise
    for t in range(T, 0, -1):
        z = torch.randn_like(x) if t > 1 else 0
        eps = model(x, t)
        x = (1 / torch.sqrt(alpha_bars[t])) * \
             (x - (1 - alpha_bars[t]) / torch.sqrt(1 - alpha_bars[t]) * eps) + \
             torch.sqrt(1 - alpha_bars[t]) * z
    return x

扩展阅读

Comments