张芷铭的个人博客

Diffusion Transformers(DiT)融合 Transformer 架构与扩散模型,通过全局建模能力和卓越扩展性重塑图像与视频生成范式。

概述

DiT 是一种将 Transformer 架构与扩散模型相结合的生成式模型。通过替换传统 U-Net 主干,利用 Transformer 的全局建模能力显著提升生成质量和效率。

传统扩散模型核心:

  • 正向过程:$q(\mathbf{x}t|\mathbf{x}{t-1})=\mathcal{N}(\sqrt{1-\beta_t}\mathbf{x}_{t-1},\beta_t \mathbf{I})$
  • 逆向过程:$p_\theta(\mathbf{x}{t-1}|\mathbf{x}t)=\mathcal{N}(\mathbf{x}{t-1};\mu\theta(\mathbf{x}t,t),\Sigma\theta)$

DiT 克服 U-Net 三大局限:扩展瓶颈、架构割裂、全局依赖建模不足。

架构设计

整体框架

建立在 Latent Diffusion Model(LDM)框架上:

  1. 编码阶段:VAE 编码器将图像压缩至潜在空间
  2. 扩散阶段:在潜在空间执行扩散过程
  3. 解码阶段:VAE 解码器恢复为像素空间

核心组件

Patchify 模块

将空间表示转换为 token 序列:

1
2
3
4
5
6
7
8
9
class Patchify(nn.Module):
    def __init__(self, patch_size, in_channels, embed_dim):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embed_dim,
                            kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)  # (B,C,H,W) -> (B,d,H/p,W/p)
        return x.flatten(2).permute(0,2,1)  # -> (B,N,d)

条件注入机制

DiT 探索三种方式:

  1. In-Context Conditioning:时间步和类别嵌入作为额外 token
  2. Cross-Attention:自注意力后添加交叉注意力层
  3. Adaptive Layer Norm (adaLN):动态生成 LayerNorm 参数

adaLN-Zero 被证明最优:

1
2
3
4
5
6
class DiTBlock(nn.Module):
    def forward(self, x, c):
        gamma1, beta1, alpha1, gamma2, beta2, alpha2 = self.adaLN_modulation(c).chunk(6, 1)
        x = x + alpha1 * self.attn(gamma1 * self.norm1(x) + beta1)[0]
        x = x + alpha2 * self.mlp(gamma2 * self.norm2(x) + beta2)
        return x

初始化为恒等函数(α 初始为 0),确保训练稳定性。

训练与优化

可扩展性设计

通过三维度实现扩展:

  • 深度:DiT-S(12 层)→ DiT-XL(28 层)
  • 宽度:隐藏维度从 384 到 1152
  • Token 数量:减少 patch 尺寸增加序列长度

DiT-XL/2 在 ImageNet 256×256 达到 FID 2.27,超越所有 U-Net 扩散模型。

条件机制对比

机制FID↓训练速度参数量
In-Context5.211.0x最小
Cross-Attention4.580.85x增加 20%
adaLN3.750.95x不变
adaLN-Zero2.270.98x微增

扩展变体

模型特点
U-ViT融合跳跃连接,中间层特征残差聚合
MDT掩码潜在建模增强语义学习
DiffiTU-Net 层级结构 + Time-dependent Self-Attention

实战应用

Sora 中的 DiT

核心组件:

  1. VAE 编码器:压缩视频帧至潜在空间
  2. ViT 分词器:时空块转换为 token 序列
  3. DiT 主干:噪声预测

图像生成示例

1
2
3
4
5
6
from diffusers import DiTPipeline, DPMSolverMultistepScheduler

pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
input_ids = pipe.get_label_ids(["strawberry", "cat"])
image = pipe(input_ids=input_ids).images[0]

未来方向

  • 多模态对齐:文本-图像-视频统一 DiT 框架
  • 3D 生成:扩展时空块处理能力
  • 自监督学习:结合 MAE 等预训练策略

Comments