MMDiT(Multi-Modal Diffusion Transformer)是 Stable Diffusion 3 的骨干架构。在 DiT 基础上引入双路并行 Transformer + 双向交叉注意力,让图像与文本模态在每一层深度互动,配合 Rectified Flow 实现训练稳定与极少步采样。

演进脉络

阶段关键节点
2017Transformer 诞生
2020DDPM 奠定现代扩散
2022Meta 提出 DiT:纯 Transformer + adaLN-Zero,打破 U-Net 缩放瓶颈
2024Stability AI 在 SD3 落地 MMDiT:双路 + 双向 cross-attention + Rectified Flow
2024-2026扩展到视频/3D/多模态联合生成;轻量化与极少步采样工业落地

Rectified Flow(前置基础)

整流流是 ODE 流生成模型,比 DDPM 训练更稳、采样更短。

线性插值路径

为真实样本,

学习速度场 拟合 ,损失:

采样:从 出发求解 ODE 反向积分到

优势:训练稳定,采样可压缩到 1–4 步,无需复杂噪声调度器。

adaLN-Zero(前置基础)

DiT 用全局条件(时间步 + 类别/文本嵌入)生成 LayerNorm 的缩放、偏移、门控参数:

调制后的 Transformer 块:

Zero 初始化:调制层权重置零 → 训练初期为恒等映射 → 深层稳定收敛。

MMDiT 核心:双路 + 双向 cross

输入输出

形状 / 说明
图像 latent ,VAE 16× 下采样,C=16(SD3)
时间步嵌入 ,正弦位置编码
文本嵌入 ,T5 / CLIP 编码器
输出 形状,速度场

单个双路 Transformer 块(4 步)

1) 全局条件调制:图像/文本各自独立的 adaLN 参数

文本支路同理。

2) 模态内自注意力(图像/文本各做一次 MHSA)

3) 模态间双向交叉注意力(MMDiT 的核心创新)

图像 → 文本:Q=img,  KV=text   ← 图像感知文本
文本 → 图像:Q=text, KV=img    ← 文本感知图像

不同于传统单向 cross,MMDiT 让两个模态深度互相感知。

4) FFN + 残差(图像/文本各一次,沿用 adaLN 调制)

输出层

堆叠 N 个双路块后,对图像支路再做一次 adaLN + 投影,恢复

六大优势

  1. 缩放性:模型从 800M → 8B 参数线性提升,无明显饱和
  2. 多模态对齐:双向 cross 解决”提示词不匹配/多物体遗漏/长文本理解”
  3. 训练/采样稳定:adaLN-Zero + Rectified Flow,最少 1 步采样
  4. 可扩展模态:新增模态只需增加支路与 cross 模块,已扩到深度图、草图、音频、视频、3D
  5. 长距离建模:全局注意力解决肢体结构、复杂布局、全局光影一致性
  6. 微调友好:兼容 LoRA / QLoRA / DoRA 等低秩方案

应用场景

类型代表
文生图SD3 系列,4K 原生生成
图像编辑Inpainting / Outpainting / 风格迁移
可控生成文本 + 深度图/草图/姿态/语义图,像素级精控
文生视频Stable Video 3,时空注意力
3D 生成3D latent 输入,文本驱动 3D 模型
多模态联合文/图/音/视频联合生成,通用生成式 AI 底座
垂直领域医疗影像、遥感、工业缺陷检测

极简 PyTorch 实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
 
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.out_proj = nn.Linear(dim, dim)
 
    def forward(self, q, kv=None):
        kv = q if kv is None else kv
        q = rearrange(self.q_proj(q),  "B N (H D) -> B H N D", H=self.num_heads)
        k = rearrange(self.k_proj(kv), "B N (H D) -> B H N D", H=self.num_heads)
        v = rearrange(self.v_proj(kv), "B N (H D) -> B H N D", H=self.num_heads)
        attn = F.softmax(q @ k.transpose(-2, -1) * self.scale, dim=-1)
        out = rearrange(attn @ v, "B H N D -> B N (H D)")
        return self.out_proj(out)
 
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim=None):
        super().__init__()
        hidden_dim = hidden_dim or dim * 4
        self.net = nn.Sequential(nn.Linear(dim, hidden_dim), nn.GELU(),
                                 nn.Linear(hidden_dim, dim))
 
    def forward(self, x):
        return self.net(x)
 
class MMDiTBlock(nn.Module):
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 8 * dim))
        self.img_self_attn  = Attention(dim, num_heads)
        self.text_self_attn = Attention(dim, num_heads)
        self.img_cross_attn  = Attention(dim, num_heads)
        self.text_cross_attn = Attention(dim, num_heads)
        self.img_ffn  = FeedForward(dim)
        self.text_ffn = FeedForward(dim)
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
 
    def forward(self, img_feat, text_feat, t_cond):
        params = self.modulation(t_cond).chunk(8, dim=-1)
        ig, ib, ia, iaf, tg, tb, ta, taf = params
 
        # 1) 模态内自注意力
        img_feat  = img_feat  + ia.unsqueeze(1)  * self.img_self_attn (self.norm(img_feat)  * ig.unsqueeze(1) + ib.unsqueeze(1))
        text_feat = text_feat + ta.unsqueeze(1)  * self.text_self_attn(self.norm(text_feat) * tg.unsqueeze(1) + tb.unsqueeze(1))
 
        # 2) 双向交叉注意力
        img_feat  = img_feat  + self.img_cross_attn (q=self.norm(img_feat),  kv=text_feat)
        text_feat = text_feat + self.text_cross_attn(q=self.norm(text_feat), kv=img_feat)
 
        # 3) FFN
        img_feat  = img_feat  + iaf.unsqueeze(1) * self.img_ffn (self.norm(img_feat)  * ig.unsqueeze(1) + ib.unsqueeze(1))
        text_feat = text_feat + taf.unsqueeze(1) * self.text_ffn(self.norm(text_feat) * tg.unsqueeze(1) + tb.unsqueeze(1))
        return img_feat, text_feat
 
class MMDiT(nn.Module):
    def __init__(self, dim=1024, num_heads=16, num_blocks=24, in_channels=16):
        super().__init__()
        self.img_proj = nn.Linear(in_channels, dim)
        self.time_mlp = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.blocks   = nn.ModuleList([MMDiTBlock(dim, num_heads) for _ in range(num_blocks)])
        self.out_norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.out_mod  = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim))
        self.out_proj = nn.Linear(dim, in_channels)
 
    def forward(self, x_t, t_emb, text_emb):
        B, H, W, C = x_t.shape
        img_feat = self.img_proj(rearrange(x_t, "B H W C -> B (H W) C"))
        t_cond = self.time_mlp(t_emb)
        for block in self.blocks:
            img_feat, text_emb = block(img_feat, text_emb, t_cond)
        og, ob = self.out_mod(t_cond).chunk(2, dim=-1)
        out = self.out_proj(self.out_norm(img_feat) * og.unsqueeze(1) + ob.unsqueeze(1))
        return rearrange(out, "B (H W) C -> B H W C", H=H, W=W)

用 Diffusers 跑 SD3

pip install diffusers transformers accelerate torch safetensors
import torch
from diffusers import StableDiffusion3Pipeline
 
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16,
).to("cuda")
 
prompt = "一只坐在樱花树下的柴犬,午后阳光,日系清新风格,8K 高清"
image = pipe(prompt, height=1024, width=1024,
             num_inference_steps=28, guidance_scale=7.0).images[0]
image.save("mmdit_output.png")

2024–2026 主要进展

方向代表
缩放SD3 系列 800M → 8B;SD3.5 优化块结构与长文本理解
注意力优化FlashAttention、窗口注意力、线性注意力,复杂度从 O(N²) → O(N)
量化GPTQ / AWQ / HQQ 4-bit,显存 ↓75%,8B 模型可在 16 GB 卡上运行
蒸馏SD3 Tiny ≈ 300M 参数,手机端实时生成
多模态双路 → 多路,文本/图像/音频/视频/3D 联合生成
可控性ControlNet、IP-Adapter 与 MMDiT 深度融合
实时Rectified Flow 1 步采样,prompt 即时反馈
微调LoRA / QLoRA / DoRA 适配,消费级卡可微调

学习资源