张芷铭的个人博客

变分法通过优化泛函极值近似复杂分布,是 VAE、DDPM 等生成模型的理论基础。

变分原理的基本概念

变分(Variational)源自变分法,研究在给定约束条件下极小化或极大化泛函的方法。泛函是将函数映射到实数的规则:

$$J[f] = \int_a^b F(x, f(x), f’(x)) , dx$$

变分推理

变分推理通过引入可调分布族,优化找到最优分布来近似目标后验分布。

核心思想

贝叶斯推理中后验分布:

$$p(\mathbf{z} | \mathbf{x}) = \frac{p(\mathbf{x} | \mathbf{z}) p(\mathbf{z})}{p(\mathbf{x})}$$

直接计算困难,因为边际似然 $p(\mathbf{x})$ 需要对所有潜变量积分。变分推理使用简单分布 $q(\mathbf{z})$ 逼近真实后验,最小化 KL 散度:

$$\text{KL}(q(\mathbf{z}) | p(\mathbf{z} | \mathbf{x})) = \mathbb{E}_{q(\mathbf{z})} \left[ \log \frac{q(\mathbf{z})}{p(\mathbf{z} | \mathbf{x})} \right]$$

变分下界(ELBO)

$$\log p(\mathbf{x}) \geq \mathbb{E}_{q(\mathbf{z})} [\log p(\mathbf{x}, \mathbf{z})] - \text{KL}(q(\mathbf{z}) | p(\mathbf{z}))$$

优化目标:

$$\mathcal{L}{\text{VI}} = \mathbb{E}{q(\mathbf{z})} [\log p(\mathbf{x}, \mathbf{z})] - \mathbb{E}_{q(\mathbf{z})} [\log q(\mathbf{z})]$$

变分自动编码器(VAE)

VAE 是变分推理的典型应用,通过编码器-解码器结构学习潜在表示:

$$\mathcal{L}{\text{VAE}} = \mathbb{E}{q(\mathbf{z} | \mathbf{x})} [\log p(\mathbf{x} | \mathbf{z})] - \text{KL}(q(\mathbf{z} | \mathbf{x}) | p(\mathbf{z}))$$

含义
第一项重构误差(采样近似)
第二项KL 散度,约束近似分布与先验分布的差距

代码示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import torch.nn as nn

class VAE(nn.Module):
    def __init__(self, latent_dim=20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, latent_dim)  # 均值
        self.fc22 = nn.Linear(400, latent_dim)  # 对数方差
        self.fc3 = nn.Linear(latent_dim, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = torch.relu(self.fc1(x.view(-1, 784)))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h3 = torch.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

def loss_function(recon_x, x, mu, logvar):
    BCE = nn.functional.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    KL = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KL

应用总结

领域应用
变分法数学优化,寻找泛函极值
变分推理贝叶斯推理中近似后验分布
VAE生成模型中学习潜在空间表示
DDPM扩散模型中推导训练目标

Comments