变分法通过优化泛函极值近似复杂分布,是 VAE、DDPM 等生成模型的理论基础。
变分原理的基本概念
变分(Variational)源自变分法,研究在给定约束条件下极小化或极大化泛函的方法。泛函是将函数映射到实数的规则:
变分推理
变分推理通过引入可调分布族,优化找到最优分布来近似目标后验分布。
核心思想
贝叶斯推理中后验分布:
直接计算困难,因为边际似然 需要对所有潜变量积分。变分推理使用简单分布 逼近真实后验,最小化 KL 散度:
变分下界(ELBO)
优化目标:
变分自动编码器(VAE)
VAE 是变分推理的典型应用,通过编码器-解码器结构学习潜在表示:
| 项 | 含义 |
|---|---|
| 第一项 | 重构误差(采样近似) |
| 第二项 | KL 散度,约束近似分布与先验分布的差距 |
代码示例
import torch
import torch.nn as nn
class VAE(nn.Module):
def __init__(self, latent_dim=20):
super(VAE, self).__init__()
self.fc1 = nn.Linear(784, 400)
self.fc21 = nn.Linear(400, latent_dim) # 均值
self.fc22 = nn.Linear(400, latent_dim) # 对数方差
self.fc3 = nn.Linear(latent_dim, 400)
self.fc4 = nn.Linear(400, 784)
def encode(self, x):
h1 = torch.relu(self.fc1(x.view(-1, 784)))
return self.fc21(h1), self.fc22(h1)
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z):
h3 = torch.relu(self.fc3(z))
return torch.sigmoid(self.fc4(h3))
def forward(self, x):
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
return self.decode(z), mu, logvar
def loss_function(recon_x, x, mu, logvar):
BCE = nn.functional.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
KL = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + KL应用总结
| 领域 | 应用 |
|---|---|
| 变分法 | 数学优化,寻找泛函极值 |
| 变分推理 | 贝叶斯推理中近似后验分布 |
| VAE | 生成模型中学习潜在空间表示 |
| DDPM | 扩散模型中推导训练目标 |