46.变分自编码器 VAE

变分自编码器(Variational Auto-Encoders,VAE)

VAE是生成数据用的,GAN(对抗神经网络)也是生成数据用的

在上一节的自编码器也可以生成数据,但是它对中间encode的变量是由要求的,VAE可以理解为是自编码器的decode部分的改良,它对中间变量有限制,方法为限制中间(隐含)变量的KL散度

  • KL散度越小,两个变量的分布就越相近,我们一般搞一个预设的分布变量,比如正太分布,通过降低KL散度,可以是中间变量也趋近于正太分布

KL散度和我们后面使用的损失函数有关,像U的那个参数是分布的平均值,不像U的那个参数是分布的标准差

下面我们使用mnist实现变分自编码器,变分自编码器大致有三步

  • 编码过程:通过图像得到均值与标准差
  • 重参数化:通过均值与标准差得到隐含变量
  • 解码过程:通过隐含变量得到重构的图像

目录

1  导入库

2  数据预处理

3  创建神经网络

4  定义损失与优化器

5  定义训练步骤

6  定义训练过程

7  训练

8  使用 VAE


1  导入库

2  数据预处理

先读取数据,然后归一化,之后降为2为,然后改为float32类型,之后创建数据集,再之后设置随机与批次

46.变分自编码器 VAE_第1张图片

3  创建神经网络

对应一开始写的三个步骤,我们的神经网络应该这样创建

46.变分自编码器 VAE_第2张图片

  • Dense层1与Dense层4是输入数据后将数据扁平化的

在层的神经元个数中,除了最后的784是定的(28*28),其余都可以自定,之后写上面提到的三个步骤,这里的std我们认为是标准差的对数的2倍,下面的true_std是真实的标准差

46.变分自编码器 VAE_第3张图片

创建自定义模型后,将模型实例化

4  定义损失与优化器

我们一共有两种损失函数,一种是预测的图与原图像不像,这个损失我们命名为BCE_loss,另一种的分布情况,就是上面提到的公式,我们定义为KLD_loss,我们最终使用的loss是这两个loss的和

当图像越相近的时候BCE_loss会越低,但有可能造成KLD_loss越高,当图像越符合正太分布时,图像就有可能越不像,就会导致BCE_loss越高。BCE_loss与KLD_loss为对抗关系,机器会在两种loss中寻求平衡,以达到总loss最低的效果,我们也可以在返回值中加权,比如我想让图像更像,那么我可以 return 0.8 * BCE_loss + 0.2 * KLD_loss,这里的权重值是超参数,超参数的和加起来不一定为1

当然我们定义权重的时候最好提前看一下两个loss的数量级

46.变分自编码器 VAE_第4张图片

一个是零点几,另一个是-425,这个加在一起就太偏向KLD了,我们现在令其均衡些,如果我们定义为0.001那么有可能会出现loss值为负值的情况,为了避免这个情况,我们定义KLD_loss的权重为0.0001

46.变分自编码器 VAE_第5张图片

这样它变成0.4,这还差不多

下面定义优化器

5  定义训练步骤

创建一个指标求每个epoch的loss的均值

之后定义训练步骤

46.变分自编码器 VAE_第6张图片

  • model()是执行model的call方法,model.predict(),是执行model的predict方法

6  定义训练过程

跟之前的自定义变量内容相似,由于是生成图像,所以我们训练每一个epoch都展示一次生成的图像

46.变分自编码器 VAE_第7张图片

7  训练

第一个epoch的loss是0.323,从图像上来看感觉是写了一堆8

46.变分自编码器 VAE_第8张图片

epoch的loss是0.210,隐约能看见几个1了

46.变分自编码器 VAE_第9张图片

最后一个epoch是0.204,效果是这样的

46.变分自编码器 VAE_第10张图片

能看出个大概,loss还有下降的趋势,如果多训练几轮可能效果更好

46.变分自编码器 VAE_第11张图片

保存的时候直接使用model.save会报错,所以我们在这里仅保存权重,之后使用test的时候再读权重

8  使用 VAE

导入库后创建模型,之后读取模型权重

46.变分自编码器 VAE_第12张图片

之后创建均值向量与标准差向量,np.linspace的第一个参数是最小值,第二个参数是最大值,第三个参数是个数

使用随机出来的两个值创建隐藏向量,之后对其进行解码,然后reshape成图像的样子

之后将生成的图像展示出来

46.变分自编码器 VAE_第13张图片

46.变分自编码器 VAE_第14张图片

你可能感兴趣的:(tensorflow笔记,深度学习,神经网络,机器学习)