张芷铭的个人博客

Flow Matching 通过学习速度场实现噪声分布到数据分布的平滑转换,兼具高质量样本与快速采样优势。

核心概念

Flow Matching 是基于连续归一化流(CNF) 的生成模型训练框架。通过学习时间相关的向量场,将简单先验分布转换为复杂目标数据分布,实现无模拟训练

与其他模型对比

模型优势局限
GAN快速采样训练不稳定,模式坍塌
VAE稳定训练生成模糊
扩散模型高质量采样慢(100+步)
Flow Matching高质量 + 快采样实现复杂

数学原理

常微分方程

$$\frac{d}{dt}\phi_t(x) = v_t(\phi_t(x)), \quad \phi_0(x) = x$$

条件流匹配损失

$$\mathcal{L}{\text{CFM}}(\theta) = \mathbb{E}{t,q(x_1),p_t(x|x_1)} | v_t(x) - u_t(x|x_1) |^2$$

代码实现

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class FlowModel(nn.Module):
    def __init__(self, input_dim=2, time_embed_dim=64):
        super().__init__()
        self.time_embed = nn.Sequential(
            nn.Linear(1, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )
        self.net = nn.Sequential(
            nn.Linear(input_dim + time_embed_dim, 128),
            nn.SiLU(),
            nn.Linear(128, input_dim)
        )

    def forward(self, x, t):
        t_embed = self.time_embed(t)
        xt = torch.cat([x, t_embed], dim=-1)
        return self.net(xt)

def flow_matching_loss(model, x0, x1, t):
    xt = (1 - t) * x0 + t * x1
    v_target = x1 - x0
    v_pred = model(xt, t)
    return torch.mean((v_pred - v_target)**2)

前沿变体

方法特点
最优传输 Flow MatchingWasserstein 距离设计最短路径
随机插值增强生成多样性
整流 Flow Matching梯度裁剪稳定训练
等变 Flow MatchingSE(3)等变性,分子生成

应用场景

  • 高分辨率图像生成(>1024×1024)
  • 跨模态条件生成
  • 视频预测
  • 3D 内容生成
  • 蛋白质折叠预测

扩展阅读

Comments