变分法通过优化泛函极值近似复杂分布,是 VAE、DDPM 等生成模型的理论基础。

变分原理的基本概念

变分(Variational)源自变分法,研究在给定约束条件下极小化或极大化泛函的方法。泛函是将函数映射到实数的规则:

变分推理

变分推理通过引入可调分布族,优化找到最优分布来近似目标后验分布。

核心思想

贝叶斯推理中后验分布:

直接计算困难,因为边际似然 需要对所有潜变量积分。变分推理使用简单分布 逼近真实后验,最小化 KL 散度:

变分下界(ELBO)

优化目标:

变分自动编码器(VAE)

VAE 是变分推理的典型应用,通过编码器-解码器结构学习潜在表示:

含义
第一项重构误差(采样近似)
第二项KL 散度,约束近似分布与先验分布的差距

代码示例

import torch
import torch.nn as nn
 
class VAE(nn.Module):
    def __init__(self, latent_dim=20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, latent_dim)  # 均值
        self.fc22 = nn.Linear(400, latent_dim)  # 对数方差
        self.fc3 = nn.Linear(latent_dim, 400)
        self.fc4 = nn.Linear(400, 784)
 
    def encode(self, x):
        h1 = torch.relu(self.fc1(x.view(-1, 784)))
        return self.fc21(h1), self.fc22(h1)
 
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
 
    def decode(self, z):
        h3 = torch.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))
 
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
 
def loss_function(recon_x, x, mu, logvar):
    BCE = nn.functional.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    KL = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KL

应用总结

领域应用
变分法数学优化,寻找泛函极值
变分推理贝叶斯推理中近似后验分布
VAE生成模型中学习潜在空间表示
DDPM扩散模型中推导训练目标