自编码器keras实现数值输入_变分自编码器+要点综述+代码实现+生成图片

1.VAE的结构

变分自编码器(Variational Autoencoders)是由Diederik Kingma和Max Welling在2014年提出来的。

1.1 网络结构

VAE的基本结构如下图所示,来自《Hands On ML》: Figure 15-11:Variational autoencoder (left), and an instance going through it (right)。

自编码器keras实现数值输入_变分自编码器+要点综述+代码实现+生成图片_第1张图片

上述的Hidden1,Hidden2可以使密集层或者卷基层。

模型结构上来看Variational Autoencoders 和Autoencoders 的区别主要在Coding Layer:VAE希望在模型输入和输出尽量相同的时候,同时编码层拟合一个正态分布(例如标准正态分布)。VAE希望编码层最终的输出向量仿佛是从一个正态分布抽样出来的,同时还和输入很像。

其实现方法是VAE编码层拟合了一个正态分布的均值向量μ和方差向量(log(σ^2)),然后用二者计算(这个计算过程又叫reparameterize)得到一个带有随机性的向量。再对这个向量进行解码,希望解码之后的Outputs尽量和Inputs相同。

1.2 作为生成模型的VAE

当训练完一个VAE之后,随机生成若干正态分布向量(根据训练时候Coding Layer预设的正态分布函数,例如标准正态分布)。用训练好的VAE的解码部分进行计算。就可以得到若干“很像”原始数据的数据或者图片。

1.3 更多思想的细节

关于VAE的更多细节,请参考一篇写的很好的知文 变分自编码器VAE:原来是这么一回事 | 附开源代码 。

2.VAE的损失函数

【说明:VAE在代码实现的过程中,个人感觉最难理解的就是损失函数,其他部分和常规的神经网络差不多】

作为自编码器,损失函数肯定要考虑Inputs和Outputs的相似性,所以需要有Reconstruction Loss,这个比较好理解和实现——就是衡量输入和输出的差异。

变分自编码器由于其Coding Layer的特殊性,其在Coding Layer需要“拟合”一个标准正态分布。所以需要衡量这个“拟合的程度”。所以VAE的损失函数还需要一个Latent Loss。

VAE最终的损失函数 Loss=Reconstruction_Loss+Latent_Loss

2.1 重构损失-Reconstruction Loss

衡量输出和输出之间误差的方法应该很多。参考了几个资料,基本都是交叉熵来作为损失函数(直接用输出输出的均方误差似乎收敛的很慢)。 用tensorflow代码表示重构误差如下(其中X是样本输入,logits是VAE模型输出):

reconstruction_loss 

2.2 KL-Latent Loss

VAE的Coding Layer希望拟合一个标准正态分布,因此衡量这种“拟合的程度”最好理解的一种方法就是KL散度——KL散度给出两个分布的差异的度量。 KL散度越大,两个分布差别越大。

KL散度损失函数的推导公式,请参考变分自编码器KL散度的Latent Loss推导 。

2.3 ELBO-Latent Loss

【说明:下述的公式中,x可以理解为输入VAE的样本;z为隐变量,可以理解为VAE的encoder的输出;q(z)表示隐变量的目标分布;p(z|x)表示拟合出的分布;p(x)为样本数据的实际分布】

ELBO是指evidence lower bound。

自编码器keras实现数值输入_变分自编码器+要点综述+代码实现+生成图片_第2张图片

自编码器keras实现数值输入_变分自编码器+要点综述+代码实现+生成图片_第3张图片

2.4 更多理论细节

请参考 Notes on Variational Autoencoders 。

3. 代码实现

3.1 CVAE训练和生产图片-全部代码

下述代码实现,主要来自于 Tensorflow-Convolutional Variational Autoencoder的教程。教程中用的ELBO latent loss。自己在实验过程中,添加了KL Latent Loss的函数,看上去实验结果二者差距不是很大。

这个代码比较方便的地方是,如果需要把卷积层修改成其他类型的层(例如密集层),可以在类里面直接修改,很方便。

import 

3.2 epochs=50两种损失函数生成的图片

两次试验损失函数包含相同的 Reconstruction Loss。

自编码器keras实现数值输入_变分自编码器+要点综述+代码实现+生成图片_第4张图片

你可能感兴趣的:(自编码器keras实现数值输入)