张芷铭的个人博客

DiT(Diffusion Transformer)将 Transformer 作为扩散模型骨干网络替代传统 U-Net,利用全局注意力机制和卓越扩展性,显著提升图像与视频生成的质量和效率。

核心思想

DiT 通过 Transformer 替代传统 U-Net,克服 U-Net 的三大局限:

  1. 扩展瓶颈:U-Net 增大模型时性能回报递减
  2. 架构割裂:与 NLP 等领域主流架构不统一
  3. 全局依赖建模不足:卷积操作感受野有限

架构设计

整体框架

建立在 LDM(Latent Diffusion Model) 框架上:

  1. 编码阶段:VAE 编码器将图像压缩至潜在空间
  2. 扩散阶段:在潜在空间执行扩散过程
  3. 解码阶段:VAE 解码器恢复为像素空间

核心组件

Patchify 模块

将空间表示转换为 Transformer 可处理的序列,patch_size 决定 token 数量 $T=(H/p)×(W/p)$。

条件注入机制

机制特点
In-Context Conditioning将时间步和类别嵌入作为额外 token 拼接
Cross-Attention在自注意力后添加条件信息交叉注意力层
adaLN-Zero动态生成 LayerNorm 参数,初始化为恒等函数,最优方案

adaLN-Zero 实现

1
2
3
4
5
6
class DiTBlock(nn.Module):
    def forward(self, x, c):
        gamma1, beta1, alpha1, gamma2, beta2, alpha2 = self.adaLN_modulation(c).chunk(6, 1)
        x = x + alpha1 * self.attn(gamma1 * self.norm1(x) + beta1)[0]
        x = x + alpha2 * self.mlp(gamma2 * self.norm2(x) + beta2)
        return x

可扩展性设计

维度扩展方式
深度DiT-S(12 层)→ DiT-XL(28 层)
宽度隐藏维度从 384(S)到 1152(XL)
Token 数量减少 patch 尺寸(p=8→2)增加序列长度

条件机制性能对比

机制FID↓特点
In-Context5.21参数量最小
Cross-Attention4.58参数增加 20%
adaLN3.75参数不变
adaLN-Zero2.27初始化为恒等函数,训练稳定

扩展与变体

模型特点
U-ViT融合 U-Net 跳跃连接,所有中间层特征残差聚合
MDT掩码潜在建模,训练时随机 mask 30-50% patch token
DiffiT结合 U-Net 层级结构与 Transformer,Time-dependent Self-Attention

应用案例

Sora 中的 DiT 实现

  1. VAE 编码器:压缩视频帧至潜在空间
  2. ViT 分词器:将时空块转换为 token 序列
  3. DiT 主干:在扩散过程中处理噪声预测

图像生成示例

1
2
3
4
5
6
7
from diffusers import DiTPipeline, DPMSolverMultistepScheduler

pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

input_ids = pipe.get_label_ids(["strawberry", "cat"])
image = pipe(input_ids=input_ids).images[0]

未来方向

  • 多模态对齐:文本-图像-视频统一 DiT 框架
  • 3D 生成:扩展时空块处理能力
  • 自监督学习:结合 MAE 等预训练策略
  • 硬件协同设计:针对 Transformer 特性优化芯片架构

Comments