变分编码器是自动编码器的升级版本,其结构跟自动编码器是类似的,也由编码器和解码器构成。
回忆一下,自动编码器有个问题,就是并不能任意生成图片,因为我们没有办法自己去构造隐藏向量,需要通过一张图片输入编码我们才知道得到的隐含向量是什么,这时我们就可以通过变分自动编码器来解决这个问题。
其实原理特别简单,只需要在编码过程给它增加一些限制,迫使其生成的隐含向量能够粗略的遵循一个标准正态分布,这就是其与一般的自动编码器最大的不同。
这样我们生成一张新图片就很简单了,我们只需要给它一个标准正态分布的随机隐含向量,这样通过解码器就能够生成我们想要的图片,而不需要给它一张原始图片先编码。
一般来讲,我们通过 encoder 得到的隐含向量并不是一个标准的正态分布,为了衡量两种分布的相似程度,我们使用 KL divergence,利用其来表示隐含向量与标准正态分布之间差异的 loss,另外一个 loss 仍然使用生成图片与原图片的均方误差来表示。
KL divergence 的公式如下
为了避免计算 KL divergence 中的积分,我们使用重参数的技巧,不是每次产生一个隐含向量,而是生成两个向量,一个表示均值,一个表示标准差,这里我们默认编码之后的隐含向量服从一个正态分布的之后,就可以用一个标准正态分布先乘上标准差再加上均值来合成这个正态分布,最后 loss 就是希望这个生成的正态分布能够符合一个标准正态分布,也就是希望均值为 0,方差为 1
所以标准的变分自动编码器如下
我们需要训练 mu (均值)和 logvar(方差)成正态分布,也就是让均值接近0,方差接近1。我们还需要降低解码后的图片和原图片的loss,所以我们的最终loss是 均值方差与正态分布的loss 和 解码后与编码前的loss。
reconstruction_function = nn.MSELoss(size_average=False) #MSE损失函数 def loss_function(recon_x, x, mu, logvar): """ recon_x: generating images #解码后的图片 x: origin images #原图片 mu: latent mean #均值 logvar: latent log variance #方差 """ MSE = reconstruction_function(recon_x, x) # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) KLD = torch.sum(KLD_element).mul_(-0.5) # KL divergence return MSE + KLD //图片loss和正态loss optimizer = torch.optim.Adam(net.parameters(), lr=1e-3) #解码后的tensor需要转换成图片的形式 def to_img(x): ''' 定义一个函数将最后的结果转换回图片 ''' x = 0.5 * (x + 1.) x = x.clamp(0, 1) x = x.view(x.shape[0], 1, 28, 28) return x
变分解码器的模型如下
class VAE(nn.Module): def __init__(self): super(VAE, self).__init__() self.fc1 = nn.Linear(784, 400) self.fc21 = nn.Linear(400, 20) # mean均值 self.fc22 = nn.Linear(400, 20) # var方差 self.fc3 = nn.Linear(20, 400) self.fc4 = nn.Linear(400, 784) def encode(self, x): h1 = F.relu(self.fc1(x)) return self.fc21(h1), self.fc22(h1) def reparametrize(self, mu, logvar): std = logvar.mul(0.5).exp_() eps = torch.FloatTensor(std.size()).normal_() if torch.cuda.is_available(): eps = Variable(eps.cuda()) else: eps = Variable(eps) return eps.mul(std).add_(mu) def decode(self, z): h3 = F.relu(self.fc3(z)) return F.tanh(self.fc4(h3)) def forward(self, x): mu, logvar = self.encode(x) # 编码 z = self.reparametrize(mu, logvar) # 重新参数化成正态分布 return self.decode(z), mu, logvar # 解码,同时输出均值方差
原始链接:https://github.com/L1aoXingyu/code-of-learn-deep-learning-with-pytorch/blob/master/chapter6_GAN/vae.ipynb