0%

KL散度在VAE中的代码实现

能否使用除Gauss分布外的其他分布?为什么一定要选取Gauss分布呢?

这篇笔记用于记录在VAE计算过程中,KL散度计算的实现方式,以及背后的原理分析。

参考资料: * KL散度的推导过程

在VAE中,为了解决随机采样无法求梯度的问题,假设潜变量空间的参数满足Gauss分布(一维情况): $$\begin{align} \mathcal{N}(\mu, \sigma)=\frac{1}{\sqrt{2 \pi \sigma^2}} e^{-\frac{(x-\mu)^2}{2 \sigma^2}} \end{align}$$ 然后利用均值和方差进行抽样。这种技巧称为重参数化(reparameterization)。

在应用重参数化技巧后,需要计算两个分布的差距,因此引入KL散度计算: $$\begin{align} \text{KL}\left(p_1(x) \| p_2(x)\right)=\int_x p_1(x) \ln \frac{p_1(x)}{p_2(x)} d x \end{align}$$

针对两个一维高斯分布: p1 = 𝒩1(μ1, σ1), p2 = 𝒩2(μ2, σ2), 可以计算他们的KL散度如下: $$ \begin{align} \text{KL}\left(p_1 \| p_2\right) & =\int_x p_1(x) \ln \frac{p_1(x)}{q_1(x)} d x \\ & =\int_x p_1(x) \left[\ln p_1(x) - \ln q_1(x)\right] d x \\ & =\int_x p_1(x) \left[\ln\left( \frac{1}{\sqrt{2 \pi \sigma_1^2}} e^{-\frac{(x-\mu_1)^2}{2 \sigma_1^2}}\right) - \ln\left( \frac{1}{\sqrt{2 \pi \sigma_2^2}} e^{-\frac{(x-\mu_2)^2}{2 \sigma_2^2}}\right)\right] d x \\ & =\int_x p_1(x) \left[-\ln \sigma_1 -\frac{(x-\mu_1)^2}{2 \sigma_1^2} + \ln\sigma_2 +\frac{(x-\mu_2)^2}{2 \sigma_2^2}\right] d x \\ & =\int_x p_1(x) \left[\ln \frac{\sigma_2}{\sigma_1} -\frac{(x-\mu_1)^2}{2 \sigma_1^2} +\frac{(x-\mu_2)^2}{2 \sigma_2^2}\right] d x \\ & =\log \frac{\sigma_2}{\sigma_1}+\underbrace{\int_x p_1(x) \frac{\left(x-\mu_2\right)^2}{2 \sigma_2^2} d x}_{\mathrm{B}}-\frac{1}{2}\\ \end{align} $$

关注较为复杂的第二项, 即下标 B 这一项。接下来要用的并不是带入 p1(x), 而是较为巧妙的使用 x − μ2 = (x − μ1) + (μ1 − μ2), 重新使用常数、方差等性质。 $$ \begin{align} B & =\frac{1}{2 \sigma_2^2}\int_x p_1(x)\left(x-\mu_2\right)^2 d x \\ & =\frac{1}{2 \sigma_2^2}\int_x p_1(x)\left[\left(x-\mu_1\right)+\left(\mu_1-\mu_2\right)\right]^2 d x \\ & =\frac{1}{2 \sigma_2^2}\int_x p_1(x)\left(x-\mu_1\right)^2 d x+\frac{2\left(\mu_1-\mu_2\right) }{2 \sigma_2^2}\int_x p_1(x)\left(x-\mu_1\right) d x+\left(\mu_1-\mu_2\right)^2 \\ & =\frac{1}{2 \sigma_2^2}\left[\sigma_1^2+0+\left(\mu_1-\mu_2\right)^2\right] \\ & =\frac{1}{2 \sigma_2^2}\left[\sigma_1^2+\left(\mu_1-\mu_2\right)^2\right] \end{align} $$

综合以上结果, 我们有: $$ \text{KL}\left(p_1 \| p_2\right)=\log \frac{\sigma_2}{\sigma_1}+\frac{1}{2 \sigma_2^2}\left(\sigma_1^2+\left(\mu_1-\mu_2\right)^2\right)-\frac{1}{2} $$

接下来, 回到VAE中, 由于我们将自由变量从标准正态分布中采样, 即 p2 = 𝒩2(0, 1); $$ \text{KL}\left(p_1|| p_2\right)=-\frac{1}{2} \times\left[2 \log \sigma_1+1-\sigma_1^2-\mu_1^2\right] $$

大多数的VAE代码中间学习并不是 μσ, 而是 μlogvar = log σ2,因此在代码中需要进行变换 。

代码如下:

1
KL = -0.5*torch.sum(logvar + 1 - mu.pow() - logvar.exp())