VAE详解(附pytorch)

        VAE:变分自编码器,遵循encoder-decoder结构,但是encoder的结果是正态分布的均值和方差(其实也是一般的参数,只是我们赋予了它们均值和方差的意义)

        编码是从不同的事物中提取不同的特征,关注下面的蓝色与绿色曲线,并假设它们分别是数字1与2的编码,如果他们的方差都是0,那么图像就是一条在0点处的蓝色竖线和在-2处的绿色竖线,那么编码的十分完美且清晰,但是我们这是生成模型,可以有一些多样性,那么我们的编码就可以模糊一些,就真正如下图所示,如果在绿色与蓝色的交界处进行采样,那么生成的图片可以长得既像1又既像2

        如果我们只在最终的输出端用一个loss来进行约束,那么模型中间的编码就会让方差趋向于0,从而使模型失去多样性,所以这里我们要引入额外的约束,所以这里使用KL散度来衡量两个分布之间的距离。

一般正态分布和标准正态分布的推导细节这里就不再赘述了,网上都可以找到。

VAE详解(附pytorch)_第1张图片

VAE详解(附pytorch)_第2张图片

         中间编码是先验概率,也就是每一个输入我们都会把它编码为一个正态分布函数,但是一个正太分布函数可以唯一由均值和方差确定,所以我们为了方便起见,直接从标准正态分布中采样,然后通过均值和方差来恢复为原来的分布。退一步将如果直接从各自的分布中进行采样,那么采样这个操作是不可导的,我们就没办法训练了。

 VAE详解(附pytorch)_第3张图片

VAE详解(附pytorch)_第4张图片

import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data/', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize([0.5], [0.5]),
                               ])),
    batch_size=8192, shuffle=True,num_workers=5)

class Sample(nn.Module):
    def __init__(self):
        super(Sample, self).__init__()
    def forward(self,z_mean,z_log_var):
        epsilon=torch.randn(z_mean.shape)
        epsilon=epsilon.to('cuda:1')
        return z_mean+(z_log_var/2).exp()*epsilon

class VaeEncoder(nn.Module):
    def __init__(self):
        super(VaeEncoder, self).__init__()
        self.Dense=nn.Linear(original_dim,intermediate_dim)
        self.z_mean=nn.Linear(intermediate_dim,latent_dim)
        self.z_log_var=nn.Linear(intermediate_dim,latent_dim)
        self.sample=Sample()
    def forward(self,x):
        o=torch.nn.functional.relu(self.Dense(x))
        z_mean=self.z_mean(o)
        z_log_var=self.z_log_var(o)
        o=self.sample(z_mean,z_log_var)
        return o,z_mean,z_log_var
class VaeDecoder(nn.Module):
    def __init__(self):
        super(VaeDecoder, self).__init__()
        self.Dense=nn.Linear(latent_dim,intermediate_dim)
        self.out=nn.Linear(intermediate_dim,original_dim)
        self.sigmoid=nn.Sigmoid()
    def forward(self,z):
        o=nn.functional.relu(self.Dense(z))
        o=self.out(o)
        return self.sigmoid(o)
class Vae(nn.Module):
    def __init__(self):
        super(Vae, self).__init__()
        self.encoder=VaeEncoder()
        self.decoder=VaeDecoder()
    def forward(self,x):
        o,mean,var=self.encoder(x)
        return self.decoder(o),mean,var

vae=Vae()
optim=torch.optim.Adam(vae.parameters(),lr=0.001)
device=torch.device('cuda:1')
vae.to(device)
for i in range(200):
    avg_loss=0
    for item in train_loader:

        item=item[0].reshape(item[0].size(0),-1)
        item=item.to(device)
        o,mean,log_var=vae(item)
        BCE_loss = nn.BCELoss(reduction='sum')
        reconstruction_loss = BCE_loss(o, item)
        KL_divergence = -0.5 * torch.sum(1+log_var-torch.exp(log_var)-mean**2)
        loss=reconstruction_loss+KL_divergence
        optim.zero_grad()
        loss.backward()
        avg_loss+=loss.item()
        optim.step()
    out=o.detach().to('cpu')
    digit = out[0].reshape(28, 28)
    plt.imshow(digit,cmap='Greys_r')
    plt.show()

    print('loss:',avg_loss/len(train_loader))

迭代100次的结果:

VAE详解(附pytorch)_第5张图片VAE详解(附pytorch)_第6张图片VAE详解(附pytorch)_第7张图片VAE详解(附pytorch)_第8张图片

你可能感兴趣的:(pytorch,深度学习,人工智能)