[PyTorch] Variational Autoencoder (VAE)

1 Model Overview

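  • How a variational autoencoder (VAE) works: a VAE treats the latent code produced by the encoder network as a Gaussian distribution that is regularized toward the standard Gaussian. It samples a feature from this distribution and decodes that feature, aiming to reproduce the original input. The loss is almost the same as for a plain autoencoder (AE), with one added regularization term: the KL divergence between the distribution inferred by the encoder and the standard Gaussian. This term keeps the model from degenerating into a plain AE. Driven by reconstruction error alone, training tends to push the variance toward 0, at which point the sampling noise vanishes and the model becomes an ordinary AE. (Source: https://www.jianshu.com/p/ffd493e10751)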

  • This example uses the MNIST handwritten digit dataset and generates reconstructed images.

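  • A key step in the algorithm is fitting the Gaussian distribution in the latent layer: the reparameterize function samples z from the encoder's mu and log-variance, so both parameters can be adjusted iteratively during training.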

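  • The overall pipeline: an image goes in, the encoder extracts features and fits the Gaussian distribution, and the decoder reconstructs the image from the sampled feature.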
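
The reparameterization step described above can be illustrated in isolation. The following is a minimal sketch, not the post's training code: the tensors `mu` and `log_var` are hypothetical stand-ins for encoder outputs (batch of 4, latent size 3). It shows that writing the sample as mu + eps * std lets gradients reach both parameters.

```python
import torch

torch.manual_seed(0)

# Hypothetical encoder outputs: batch of 4 samples, 3 latent dimensions.
mu = torch.zeros(4, 3, requires_grad=True)
log_var = torch.zeros(4, 3, requires_grad=True)


def reparameterize(mu, log_var):
    """Draw z ~ N(mu, sigma^2) as a differentiable function of mu and log_var."""
    std = torch.exp(0.5 * log_var)  # sigma = exp(log(sigma^2) / 2)
    eps = torch.randn_like(std)     # noise that does not depend on the parameters
    return mu + eps * std


z = reparameterize(mu, log_var)
z.sum().backward()

# dz/dmu is 1 for every element; dz/dlog_var is 0.5 * eps * std.
print(mu.grad)
print(log_var.grad)
```

Because the randomness lives entirely in `eps`, the sampling step does not block backpropagation, which is why the model can learn `mu` and `log_var` with an ordinary optimizer.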

2 Code

# variational_autoencoder
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create a directory if not exists
sample_dir = 'samples'
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

# Hyper-parameters
image_size = 784
h_dim = 400
z_dim = 20
num_epochs = 15
batch_size = 128
learning_rate = 1e-3

# MNIST dataset
dataset = torchvision.datasets.MNIST(
    root='../../data',
    train=True,
    transform=transforms.ToTensor(),
    download=True)

# Data loader
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size, 
    shuffle=True)


# VAE model
class VAE(nn.Module):
    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(image_size, h_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)
        self.fc3 = nn.Linear(h_dim, z_dim)
        self.fc4 = nn.Linear(z_dim, h_dim)
        self.fc5 = nn.Linear(h_dim, image_size)
        
    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)  # two heads: fc2 outputs mu, fc3 outputs log_var
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(log_var/2)
        eps = torch.randn_like(std)
        return mu + eps * std 

    def decode(self, z):
        h = F.relu(self.fc4(z))
        return torch.sigmoid(self.fc5(h))
    
    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var

model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Start training
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(data_loader):
        # Forward pass
        x = x.to(device).view(-1, image_size)
        x_reconst, mu, log_var = model(x)
        
        # Compute reconstruction loss and kl divergence
        # For KL divergence, see Appendix B in VAE paper or http://yunjey47.tistory.com/43
        reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')
        kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        
        # Backprop and optimize
        loss = reconst_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 100 == 0:
            print ("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}" 
                   .format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item(), kl_div.item()))
    
    with torch.no_grad():
        # Save the sampled images
        z = torch.randn(batch_size, z_dim).to(device)
        out = model.decode(z).view(-1, 1, 28, 28)
        save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))

        # Save the reconstructed images
        out, _, _ = model(x)
        x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
        save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch+1)))
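
As a sanity check on the KL term in the training loop above, the closed-form expression can be compared with the value computed by `torch.distributions`. This is a sketch under assumptions: `mu` and `log_var` are random tensors standing in for encoder outputs, and the batch size of 5 is arbitrary.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)

# Hypothetical encoder outputs: batch of 5 samples, 20 latent dimensions.
mu = torch.randn(5, 20)
log_var = torch.randn(5, 20)

# Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over batch and latent dims,
# written the same way as in the training loop.
kl_closed = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

# The same quantity from torch.distributions, for comparison.
q = Normal(mu, torch.exp(0.5 * log_var))                # approximate posterior
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))   # standard Gaussian prior
kl_reference = kl_divergence(q, p).sum()

print(kl_closed.item(), kl_reference.item())
```

The two values should agree up to floating-point error, which confirms that the one-line formula is the KL divergence between a diagonal Gaussian and the standard Gaussian.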

3 Program Output

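Besides the reconstruction loss and KL divergence logged during training, the program also saves decoded images, including VAE and AE images. Two examples from the final epoch are shown here; the VAE-generated image is noticeably more accurate.
(1) VAE-generated
[Image: figure 2 of the original post]
(2) AE-generated
[Image: figure 3 of the original post]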

Epoch[1/15], Step [100/469], Reconst Loss: 22479.4258, KL Div: 1291.5803
Epoch[1/15], Step [200/469], Reconst Loss: 18100.7598, KL Div: 1868.1145
Epoch[1/15], Step [300/469], Reconst Loss: 15728.3721, KL Div: 2310.5405
Epoch[1/15], Step [400/469], Reconst Loss: 14573.3467, KL Div: 2457.8628
Epoch[2/15], Step [100/469], Reconst Loss: 13343.7812, KL Div: 2688.8018
Epoch[2/15], Step [200/469], Reconst Loss: 12820.8320, KL Div: 2670.8801
Epoch[2/15], Step [300/469], Reconst Loss: 12177.3770, KL Div: 2857.9192
Epoch[2/15], Step [400/469], Reconst Loss: 12119.5352, KL Div: 2913.0271
Epoch[3/15], Step [100/469], Reconst Loss: 11455.6797, KL Div: 2936.3635
Epoch[3/15], Step [200/469], Reconst Loss: 11355.9854, KL Div: 2994.1992
Epoch[3/15], Step [300/469], Reconst Loss: 11565.2637, KL Div: 3027.9497
Epoch[3/15], Step [400/469], Reconst Loss: 11625.3447, KL Div: 3047.5901
Epoch[4/15], Step [100/469], Reconst Loss: 11396.5977, KL Div: 3111.6401
Epoch[4/15], Step [200/469], Reconst Loss: 11895.6436, KL Div: 3192.7432
Epoch[4/15], Step [300/469], Reconst Loss: 10787.6719, KL Div: 3120.9729
Epoch[4/15], Step [400/469], Reconst Loss: 10792.5635, KL Div: 3101.8181
Epoch[5/15], Step [100/469], Reconst Loss: 11358.7930, KL Div: 3227.3677
Epoch[5/15], Step [200/469], Reconst Loss: 10595.2998, KL Div: 3087.9536
Epoch[5/15], Step [300/469], Reconst Loss: 11012.8457, KL Div: 3079.9478
Epoch[5/15], Step [400/469], Reconst Loss: 11031.8301, KL Div: 3274.6953
Epoch[6/15], Step [100/469], Reconst Loss: 10727.4932, KL Div: 3074.3291
Epoch[6/15], Step [200/469], Reconst Loss: 10766.1553, KL Div: 3205.7544
Epoch[6/15], Step [300/469], Reconst Loss: 10917.2773, KL Div: 3153.5269
Epoch[6/15], Step [400/469], Reconst Loss: 11135.8389, KL Div: 3166.1350
Epoch[7/15], Step [100/469], Reconst Loss: 10622.8848, KL Div: 3265.8269
Epoch[7/15], Step [200/469], Reconst Loss: 10808.3926, KL Div: 3163.3755
Epoch[7/15], Step [300/469], Reconst Loss: 10255.6533, KL Div: 3148.1663
Epoch[7/15], Step [400/469], Reconst Loss: 10487.1641, KL Div: 3009.9302
Epoch[8/15], Step [100/469], Reconst Loss: 10424.5625, KL Div: 3154.1379
Epoch[8/15], Step [200/469], Reconst Loss: 10814.4883, KL Div: 3221.2366
Epoch[8/15], Step [300/469], Reconst Loss: 10307.5762, KL Div: 3272.5889
Epoch[8/15], Step [400/469], Reconst Loss: 10350.0527, KL Div: 3236.4878
Epoch[9/15], Step [100/469], Reconst Loss: 10028.5371, KL Div: 3131.3210
Epoch[9/15], Step [200/469], Reconst Loss: 10316.7715, KL Div: 3235.4766
Epoch[9/15], Step [300/469], Reconst Loss: 10969.9980, KL Div: 3212.9060
Epoch[9/15], Step [400/469], Reconst Loss: 10779.7207, KL Div: 3261.3821
Epoch[10/15], Step [100/469], Reconst Loss: 10576.8887, KL Div: 3287.4534
Epoch[10/15], Step [200/469], Reconst Loss: 10241.1055, KL Div: 3205.7427
Epoch[10/15], Step [300/469], Reconst Loss: 10066.1045, KL Div: 3187.4636
Epoch[10/15], Step [400/469], Reconst Loss: 10090.7051, KL Div: 3259.8884
Epoch[11/15], Step [100/469], Reconst Loss: 10129.8330, KL Div: 3116.4709
Epoch[11/15], Step [200/469], Reconst Loss: 10742.1025, KL Div: 3320.4324
Epoch[11/15], Step [300/469], Reconst Loss: 9563.3086, KL Div: 3134.5513
Epoch[11/15], Step [400/469], Reconst Loss: 10116.4502, KL Div: 3038.9773
Epoch[12/15], Step [100/469], Reconst Loss: 10564.5547, KL Div: 3168.8032
Epoch[12/15], Step [200/469], Reconst Loss: 10309.9707, KL Div: 3233.3108
Epoch[12/15], Step [300/469], Reconst Loss: 10618.7500, KL Div: 3306.9241
Epoch[12/15], Step [400/469], Reconst Loss: 9806.7266, KL Div: 3111.2397
Epoch[13/15], Step [100/469], Reconst Loss: 10133.7803, KL Div: 3193.5210
Epoch[13/15], Step [200/469], Reconst Loss: 9875.7354, KL Div: 3192.7917
Epoch[13/15], Step [300/469], Reconst Loss: 10075.5908, KL Div: 3207.8071
Epoch[13/15], Step [400/469], Reconst Loss: 10066.5723, KL Div: 3236.6553
Epoch[14/15], Step [100/469], Reconst Loss: 10055.6152, KL Div: 3273.6326
Epoch[14/15], Step [200/469], Reconst Loss: 9802.0449, KL Div: 3094.6011
Epoch[14/15], Step [300/469], Reconst Loss: 10179.3486, KL Div: 3199.3838
Epoch[14/15], Step [400/469], Reconst Loss: 10153.1211, KL Div: 3148.5806
Epoch[15/15], Step [100/469], Reconst Loss: 10289.4541, KL Div: 3192.3494
Epoch[15/15], Step [200/469], Reconst Loss: 10085.9668, KL Div: 3131.0674
Epoch[15/15], Step [300/469], Reconst Loss: 9941.9395, KL Div: 3189.7222
Epoch[15/15], Step [400/469], Reconst Loss: 10195.7031, KL Div: 3279.8584
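
The logged values are sums over a whole batch, since both the reconstruction loss and the KL term are summed rather than averaged. A small sketch of converting the last logged line to per-image values, assuming a full batch of 128 (true for step 400; only the final step of each epoch is smaller):

```python
batch_size = 128  # from the hyper-parameters in the code listing

# Values copied from the last logged line (epoch 15, step 400).
reconst_loss = 10195.7031
kl_div = 3279.8584

per_image_reconst = reconst_loss / batch_size
per_image_kl = kl_div / batch_size
per_image_total = per_image_reconst + per_image_kl

print(f"reconst: {per_image_reconst:.2f}, KL: {per_image_kl:.2f}, total: {per_image_total:.2f}")
```

Per-image numbers are easier to compare across runs that use a different batch size.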
