DiT(Diffusion Transformer)将 Transformer 作为扩散模型骨干网络替代传统 U-Net,利用全局注意力机制和卓越扩展性,显著提升图像与视频生成的质量和效率。
核心思想
DiT 通过 Transformer 替代传统 U-Net,克服 U-Net 的三大局限:
- 扩展瓶颈:U-Net 增大模型时性能回报递减
- 架构割裂:与 NLP 等领域主流架构不统一
- 全局依赖建模不足:卷积操作感受野有限
架构设计
整体框架
建立在 LDM(Latent Diffusion Model) 框架上:
- 编码阶段:VAE 编码器将图像压缩至潜在空间
- 扩散阶段:在潜在空间执行扩散过程
- 解码阶段:VAE 解码器恢复为像素空间
核心组件
Patchify 模块
将空间表示转换为 Transformer 可处理的序列,patch_size 决定 token 数量 $T=(H/p)×(W/p)$。
条件注入机制
| 机制 | 特点 |
|---|
| In-Context Conditioning | 将时间步和类别嵌入作为额外 token 拼接 |
| Cross-Attention | 在自注意力后添加条件信息交叉注意力层 |
| adaLN-Zero | 动态生成 LayerNorm 参数,初始化为恒等函数,最优方案 |
adaLN-Zero 实现
1
2
3
4
5
6
| class DiTBlock(nn.Module):
def forward(self, x, c):
gamma1, beta1, alpha1, gamma2, beta2, alpha2 = self.adaLN_modulation(c).chunk(6, 1)
x = x + alpha1 * self.attn(gamma1 * self.norm1(x) + beta1)[0]
x = x + alpha2 * self.mlp(gamma2 * self.norm2(x) + beta2)
return x
|
可扩展性设计
| 维度 | 扩展方式 |
|---|
| 深度 | DiT-S(12 层)→ DiT-XL(28 层) |
| 宽度 | 隐藏维度从 384(S)到 1152(XL) |
| Token 数量 | 减少 patch 尺寸(p=8→2)增加序列长度 |
条件机制性能对比
| 机制 | FID↓ | 特点 |
|---|
| In-Context | 5.21 | 参数量最小 |
| Cross-Attention | 4.58 | 参数增加 20% |
| adaLN | 3.75 | 参数不变 |
| adaLN-Zero | 2.27 | 初始化为恒等函数,训练稳定 |
扩展与变体
| 模型 | 特点 |
|---|
| U-ViT | 融合 U-Net 跳跃连接,所有中间层特征残差聚合 |
| MDT | 掩码潜在建模,训练时随机 mask 30-50% patch token |
| DiffiT | 结合 U-Net 层级结构与 Transformer,Time-dependent Self-Attention |
应用案例
Sora 中的 DiT 实现
- VAE 编码器:压缩视频帧至潜在空间
- ViT 分词器:将时空块转换为 token 序列
- DiT 主干:在扩散过程中处理噪声预测
图像生成示例
1
2
3
4
5
6
7
| from diffusers import DiTPipeline, DPMSolverMultistepScheduler
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
input_ids = pipe.get_label_ids(["strawberry", "cat"])
image = pipe(input_ids=input_ids).images[0]
|
未来方向
- 多模态对齐:文本-图像-视频统一 DiT 框架
- 3D 生成:扩展时空块处理能力
- 自监督学习:结合 MAE 等预训练策略
- 硬件协同设计:针对 Transformer 特性优化芯片架构
Comments