MMDiT(Multi-Modal Diffusion Transformer)是 Stable Diffusion 3 的骨干架构。在 DiT 基础上引入双路并行 Transformer + 双向交叉注意力,让图像与文本模态在每一层深度互动,配合 Rectified Flow 实现训练稳定与极少步采样。
演进脉络
| 阶段 | 关键节点 |
|---|---|
| 2017 | Transformer 诞生 |
| 2020 | DDPM 奠定现代扩散 |
| 2022 | Meta 提出 DiT:纯 Transformer + adaLN-Zero,打破 U-Net 缩放瓶颈 |
| 2024 | Stability AI 在 SD3 落地 MMDiT:双路 + 双向 cross-attention + Rectified Flow |
| 2024-2026 | 扩展到视频/3D/多模态联合生成;轻量化与极少步采样工业落地 |
Rectified Flow(前置基础)
整流流是 ODE 流生成模型,比 DDPM 训练更稳、采样更短。
线性插值路径:
为真实样本,。
学习速度场 拟合 ,损失:
采样:从 出发求解 ODE 反向积分到 。
优势:训练稳定,采样可压缩到 1–4 步,无需复杂噪声调度器。
adaLN-Zero(前置基础)
DiT 用全局条件(时间步 + 类别/文本嵌入)生成 LayerNorm 的缩放、偏移、门控参数:
调制后的 Transformer 块:
Zero 初始化:调制层权重置零 → 训练初期为恒等映射 → 深层稳定收敛。
MMDiT 核心:双路 + 双向 cross
输入输出
| 项 | 形状 / 说明 |
|---|---|
| 图像 latent | ,VAE 16× 下采样,C=16(SD3) |
| 时间步嵌入 | ,正弦位置编码 |
| 文本嵌入 | ,T5 / CLIP 编码器 |
| 输出 | 同 形状,速度场 |
单个双路 Transformer 块(4 步)
1) 全局条件调制:图像/文本各自独立的 adaLN 参数
文本支路同理。
2) 模态内自注意力(图像/文本各做一次 MHSA)
3) 模态间双向交叉注意力(MMDiT 的核心创新)
图像 → 文本:Q=img, KV=text ← 图像感知文本
文本 → 图像:Q=text, KV=img ← 文本感知图像不同于传统单向 cross,MMDiT 让两个模态深度互相感知。
4) FFN + 残差(图像/文本各一次,沿用 adaLN 调制)
输出层
堆叠 N 个双路块后,对图像支路再做一次 adaLN + 投影,恢复 。
六大优势
- 缩放性:模型从 800M → 8B 参数线性提升,无明显饱和
- 多模态对齐:双向 cross 解决”提示词不匹配/多物体遗漏/长文本理解”
- 训练/采样稳定:adaLN-Zero + Rectified Flow,最少 1 步采样
- 可扩展模态:新增模态只需增加支路与 cross 模块,已扩到深度图、草图、音频、视频、3D
- 长距离建模:全局注意力解决肢体结构、复杂布局、全局光影一致性
- 微调友好:兼容 LoRA / QLoRA / DoRA 等低秩方案
应用场景
| 类型 | 代表 |
|---|---|
| 文生图 | SD3 系列,4K 原生生成 |
| 图像编辑 | Inpainting / Outpainting / 风格迁移 |
| 可控生成 | 文本 + 深度图/草图/姿态/语义图,像素级精控 |
| 文生视频 | Stable Video 3,时空注意力 |
| 3D 生成 | 3D latent 输入,文本驱动 3D 模型 |
| 多模态联合 | 文/图/音/视频联合生成,通用生成式 AI 底座 |
| 垂直领域 | 医疗影像、遥感、工业缺陷检测 |
极简 PyTorch 实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
self.out_proj = nn.Linear(dim, dim)
def forward(self, q, kv=None):
kv = q if kv is None else kv
q = rearrange(self.q_proj(q), "B N (H D) -> B H N D", H=self.num_heads)
k = rearrange(self.k_proj(kv), "B N (H D) -> B H N D", H=self.num_heads)
v = rearrange(self.v_proj(kv), "B N (H D) -> B H N D", H=self.num_heads)
attn = F.softmax(q @ k.transpose(-2, -1) * self.scale, dim=-1)
out = rearrange(attn @ v, "B H N D -> B N (H D)")
return self.out_proj(out)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim=None):
super().__init__()
hidden_dim = hidden_dim or dim * 4
self.net = nn.Sequential(nn.Linear(dim, hidden_dim), nn.GELU(),
nn.Linear(hidden_dim, dim))
def forward(self, x):
return self.net(x)
class MMDiTBlock(nn.Module):
def __init__(self, dim, num_heads=8):
super().__init__()
self.modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 8 * dim))
self.img_self_attn = Attention(dim, num_heads)
self.text_self_attn = Attention(dim, num_heads)
self.img_cross_attn = Attention(dim, num_heads)
self.text_cross_attn = Attention(dim, num_heads)
self.img_ffn = FeedForward(dim)
self.text_ffn = FeedForward(dim)
self.norm = nn.LayerNorm(dim, elementwise_affine=False)
def forward(self, img_feat, text_feat, t_cond):
params = self.modulation(t_cond).chunk(8, dim=-1)
ig, ib, ia, iaf, tg, tb, ta, taf = params
# 1) 模态内自注意力
img_feat = img_feat + ia.unsqueeze(1) * self.img_self_attn (self.norm(img_feat) * ig.unsqueeze(1) + ib.unsqueeze(1))
text_feat = text_feat + ta.unsqueeze(1) * self.text_self_attn(self.norm(text_feat) * tg.unsqueeze(1) + tb.unsqueeze(1))
# 2) 双向交叉注意力
img_feat = img_feat + self.img_cross_attn (q=self.norm(img_feat), kv=text_feat)
text_feat = text_feat + self.text_cross_attn(q=self.norm(text_feat), kv=img_feat)
# 3) FFN
img_feat = img_feat + iaf.unsqueeze(1) * self.img_ffn (self.norm(img_feat) * ig.unsqueeze(1) + ib.unsqueeze(1))
text_feat = text_feat + taf.unsqueeze(1) * self.text_ffn(self.norm(text_feat) * tg.unsqueeze(1) + tb.unsqueeze(1))
return img_feat, text_feat
class MMDiT(nn.Module):
def __init__(self, dim=1024, num_heads=16, num_blocks=24, in_channels=16):
super().__init__()
self.img_proj = nn.Linear(in_channels, dim)
self.time_mlp = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim))
self.blocks = nn.ModuleList([MMDiTBlock(dim, num_heads) for _ in range(num_blocks)])
self.out_norm = nn.LayerNorm(dim, elementwise_affine=False)
self.out_mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim))
self.out_proj = nn.Linear(dim, in_channels)
def forward(self, x_t, t_emb, text_emb):
B, H, W, C = x_t.shape
img_feat = self.img_proj(rearrange(x_t, "B H W C -> B (H W) C"))
t_cond = self.time_mlp(t_emb)
for block in self.blocks:
img_feat, text_emb = block(img_feat, text_emb, t_cond)
og, ob = self.out_mod(t_cond).chunk(2, dim=-1)
out = self.out_proj(self.out_norm(img_feat) * og.unsqueeze(1) + ob.unsqueeze(1))
return rearrange(out, "B (H W) C -> B H W C", H=H, W=W)用 Diffusers 跑 SD3
pip install diffusers transformers accelerate torch safetensorsimport torch
from diffusers import StableDiffusion3Pipeline
pipe = StableDiffusion3Pipeline.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers",
torch_dtype=torch.float16,
).to("cuda")
prompt = "一只坐在樱花树下的柴犬,午后阳光,日系清新风格,8K 高清"
image = pipe(prompt, height=1024, width=1024,
num_inference_steps=28, guidance_scale=7.0).images[0]
image.save("mmdit_output.png")2024–2026 主要进展
| 方向 | 代表 |
|---|---|
| 缩放 | SD3 系列 800M → 8B;SD3.5 优化块结构与长文本理解 |
| 注意力优化 | FlashAttention、窗口注意力、线性注意力,复杂度从 O(N²) → O(N) |
| 量化 | GPTQ / AWQ / HQQ 4-bit,显存 ↓75%,8B 模型可在 16 GB 卡上运行 |
| 蒸馏 | SD3 Tiny ≈ 300M 参数,手机端实时生成 |
| 多模态 | 双路 → 多路,文本/图像/音频/视频/3D 联合生成 |
| 可控性 | ControlNet、IP-Adapter 与 MMDiT 深度融合 |
| 实时 | Rectified Flow 1 步采样,prompt 即时反馈 |
| 微调 | LoRA / QLoRA / DoRA 适配,消费级卡可微调 |