对抗变分贝叶斯:变分自编码器与生成对抗网络的统一(二)

文章目录

  • ==加载一些必要的库==
  • ==根据自己制作的图像列表的txt文档加载自己的数据集==
  • ==定义训练迭代器和测试迭代器==
  • ==原github代码部分==
  • ==开始训练迭代==
  • ==代码运行结果==
  • ==一点小问题==

在上一篇博文也就是对抗变分自编码器——AVB(Adversarial Variational Bayes: Unifying Variational Autoencoders and … ) (一)中,有关github的代码、注释和计算流程图已经贴出,但上述代码适用于图像识别领域的“Hello World!”——mnist数据集,后来我根据自己实验的需要对代码进行了一些改动:

加载一些必要的库

import torch
import torch.nn
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
# from tensorflow.examples.tutorials.mnist import input_data
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import scipy.io as sio

根据自己制作的图像列表的txt文档加载自己的数据集

# 根据自己制作的图像列表的txt文档加载自己的数据集
def default_loader(path):
    return Image.open(path).convert('RGB')


class MyDataset(torch.utils.data.Dataset):
    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        print(txt)
        fh = open(txt, 'r')
        imgs = []
        for line in fh:
            line = line.strip('\n')
            line = line.rstrip()
            words = line.split('\t')
            imgs.append((words[0], words[1]))
        print(imgs)
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(fn)

        if self.transform is not None:
            img = self.transform(img)
        else:
            img = Tensor.from_numpy(img)
        return img, label

    def __len__(self):
        return len(self.imgs)






##########################################
# 转载或使用请附上本文链接:https://blog.csdn.net/S20144144/article/details/
# 纵心似水

定义训练迭代器和测试迭代器

transform = transforms.Compose([transforms.Scale((150, 150)), transforms.ToTensor()])   # 转换为张量
train_txt_path = '.txt'  # 自己的数据集的位置列表txt
trainset = MyDataset(txt=train_txt_path, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=trainset, batch_size=mb_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=trainset, batch_size=1, shuffle=False)








##########################################
# 转载或使用请附上本文链接:https://blog.csdn.net/S20144144/article/details/
# 纵心似水

原github代码部分

再次附上github代码链接:https://github.com/wiseodd/generative-models

def log(x):
    return torch.log(x + 1e-8)

# Encoder: q(z|x,eps)   # 编码器
Q = torch.nn.Sequential(
    torch.nn.Linear(X_dim + eps_dim, h_dim),   # 一个全连接层
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, z_dim)              # 一个全连接层
)

# Decoder: p(x|z)       # 解码器
P = torch.nn.Sequential(
    torch.nn.Linear(z_dim, h_dim),             # 一个全连接层
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, X_dim),             # 一个全连接层
    torch.nn.Sigmoid()
)

# Discriminator: T(X, z)   # 判别器
T = torch.nn.Sequential(
    torch.nn.Linear(X_dim + z_dim, h_dim),     # 一个全连接层
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, 1)                  # 一个全连接层   # 输出为一维,即一个数
)

Q.cuda()
P.cuda()
T.cuda()

def reset_grad():          # 重置梯度为0
    Q.zero_grad()
    P.zero_grad()
    T.zero_grad()

Q_solver = optim.Adam(Q.parameters(), lr=lr)    # 三个模块的优化求解器
P_solver = optim.Adam(P.parameters(), lr=lr)
T_solver = optim.Adam(T.parameters(), lr=lr)

开始训练迭代

for it in range(1000000):    # 开始迭代
    print(it)
    # X = sample_X(mb_size)   # 输入为从训练集中采样并进行类型转换后的数据
    for i, (X, _) in enumerate(train_loader):
        X = X.view(-1, 150 * 150 * 3)
        X = Variable(X)
        eps = Variable(torch.randn(mb_size, eps_dim))
        z = Variable(torch.randn(mb_size, z_dim))   # 由标准正态分布(均值为0,方差为1)中随机采样

        # Optimize VAE   # 优化变分自编码器
        # z_sample = Q(torch.cat([X, eps], 1))   # 按列拼接,需要维度一致方能行对齐
        z_sample = Q(torch.cat([X, eps], 1).cuda())   # 按列拼接,需要维度一致方能行对齐
        # X_sample = P(z_sample)
        X_sample = P(z_sample.cuda())
        # T_sample = T(torch.cat([X, z_sample], 1))
        T_sample = T(torch.cat([X, z_sample.cpu()], 1).cuda())

        disc = torch.mean(-T_sample)    # 判别器输出的负数的均值
        loglike = -nn.binary_cross_entropy(X_sample, X.cuda(), size_average=False) / mb_size
        # 交叉熵, 最小化交叉熵损失函数等价于最大化对数似然, 让重构图像尽可能接近原始输入图像

        elbo = -(disc + loglike)   # 证据下界, 常用在变分推断中

        elbo.backward()    # 证据下界反向传播,优化编码器与解码器
        Q_solver.step()
        P_solver.step()
        reset_grad()     # 重置梯度为0

        # Discriminator T(X, z)     # 对于判别器,优化判别器
        # z_sample = Q(torch.cat([X, eps], 1))  # z_sample是输入经过编码器后的输出
        z_sample = Q(torch.cat([X, eps], 1).cuda())  # z_sample是输入经过编码器后的输出
        T_q = nn.sigmoid(T(torch.cat([X, z_sample.cpu()], 1).cuda()))
        T_prior = nn.sigmoid(T(torch.cat([X, z], 1).cuda()))

        T_loss = -torch.mean(log(T_q) + log(1. - T_prior))

        T_loss.backward()
        T_solver.step()
        reset_grad()     # 重置梯度为0

        if (it + 1) % 10 == 0:
            print('Iter-{}; ELBO: {:.4}; T_loss: {:.4}'
                  .format(it, -elbo.data[0], -T_loss.data[0]))

        # Print and plot every now and then
    if (it + 1) % 10 == 0:

        for k, (X, _) in enumerate(test_loader):

            if k < 3:

                X = X.view(-1, 150 * 150 * 3)
                X = Variable(X)
                eps = Variable(torch.randn(1, eps_dim))
                z = Variable(torch.randn(1, z_dim))  # 由标准正态分布(均值为0,方差为1)中随机采样

                z_sample = Q(torch.cat([X, eps], 1).cuda())  # 按列拼接,需要维度一致方能行对齐
                X_random = P(z.cuda()).data.cpu().numpy()
                reconst = P(z_sample).data.cpu().numpy()  # 原始输入输入解码器后的输出, 取前16个作为示范
                X = X.numpy()

                reconst1 = reconst.reshape(3, 150 * 150)
                reconst2 = reconst1[0, :]
                reconst3 = np.zeros(((3, 150, 150)))
                reconst3 = np.array(reconst3)
                reconst3[0, :, :] = reconst1[0, :].reshape(150, 150)
                reconst3[1, :, :] = reconst1[1, :].reshape(150, 150)
                reconst3[2, :, :] = reconst1[2, :].reshape(150, 150)
                reconst = reconst3
                # save_image(torch.from_numpy(sample), 'try/reconst_iter_' + str(it) + '_' + str(k) + '.png')

                x1 = X.reshape(3, 150 * 150)
                x2 = x1[0, :]
                x3 = np.zeros(((3, 150, 150)))
                x3 = np.array(x3)
                x3[0, :, :] = x1[0, :].reshape(150, 150)
                x3[1, :, :] = x1[1, :].reshape(150, 150)
                x3[2, :, :] = x1[2, :].reshape(150, 150)
                xx = x3
                # save_image(torch.from_numpy(xx), 'try/input_' + str(it) + '_' + str(k) + '.png')

                image_show = np.concatenate((xx, reconst), axis=2)
                image_show = image_show[np.newaxis, :, :, :]
                save_image(torch.from_numpy(image_show), 'try/compare_' + str(it) + '_' + str(k) + '.png')

                random1 = X_random.reshape(3, 150 * 150)
                random2 = random1[0, :]
                random3 = np.zeros(((3, 150, 150)))
                random3 = np.array(random3)
                random3[0, :, :] = random1[0, :].reshape(150, 150)
                random3[1, :, :] = random1[1, :].reshape(150, 150)
                random3[2, :, :] = random1[2, :].reshape(150, 150)
                random_image = random3

                save_image(torch.from_numpy(random_image), 'try/random_' + str(it) + '_' + str(k) + '.png')






##########################################
# 转载或使用请附上本文链接:https://blog.csdn.net/S20144144/article/details/
# 纵心似水

代码运行结果

代码运行结果如下所示:
compare_0_0.png:
对抗变分贝叶斯:变分自编码器与生成对抗网络的统一(二)_第1张图片
compare_0_1.png:
对抗变分贝叶斯:变分自编码器与生成对抗网络的统一(二)_第2张图片
compare_0_2.png:
对抗变分贝叶斯:变分自编码器与生成对抗网络的统一(二)_第3张图片
compare_9_0.png:
对抗变分贝叶斯:变分自编码器与生成对抗网络的统一(二)_第4张图片
compare_9_1.png:
对抗变分贝叶斯:变分自编码器与生成对抗网络的统一(二)_第5张图片
compare_9_2.png:
对抗变分贝叶斯:变分自编码器与生成对抗网络的统一(二)_第6张图片
compare_19_0.png:
对抗变分贝叶斯:变分自编码器与生成对抗网络的统一(二)_第7张图片
compare_19_1.png:
对抗变分贝叶斯:变分自编码器与生成对抗网络的统一(二)_第8张图片
compare_19_2.png:
对抗变分贝叶斯:变分自编码器与生成对抗网络的统一(二)_第9张图片
。。。。。。
random_0_0.png:
对抗变分贝叶斯:变分自编码器与生成对抗网络的统一(二)_第10张图片
random_0_1.png:
对抗变分贝叶斯:变分自编码器与生成对抗网络的统一(二)_第11张图片
random_0_2.png:
对抗变分贝叶斯:变分自编码器与生成对抗网络的统一(二)_第12张图片
random_9_0.png:
对抗变分贝叶斯:变分自编码器与生成对抗网络的统一(二)_第13张图片
random_9_1.png:
对抗变分贝叶斯:变分自编码器与生成对抗网络的统一(二)_第14张图片
random_9_2.png:
对抗变分贝叶斯:变分自编码器与生成对抗网络的统一(二)_第15张图片
random_19_0.png:
对抗变分贝叶斯:变分自编码器与生成对抗网络的统一(二)_第16张图片
random_19_1.png:
对抗变分贝叶斯:变分自编码器与生成对抗网络的统一(二)_第17张图片
random_19_2.png:
对抗变分贝叶斯:变分自编码器与生成对抗网络的统一(二)_第18张图片
。。。。。。
可以看出,随着迭代次数增加,图片生成质量也会越高。

一点小问题

代码中对隐含层的z_sample位置再次进行了高斯随机采样以生成新的人脸图像,对于对抗变分自编码器来说这是否合理?(AVB的隐含层没有显式的概率分布,其为一个黑箱模型)对于AVB有没有更好的生成新的人脸图像的采样方法呢?
这篇对抗变分自编码器——AVB(Adversarial Variational Bayes: Unifying Variational Autoencoders and … ) (二)就先写到这里,若有疏漏、不恰当或者错误的地方还请及时指出。另外,代码还需进一步的优化,若你有更好的修改方式或想法,请不吝赐教。
这里附上上一篇博文的链接:对抗变分自编码器——AVB(Adversarial Variational Bayes: Unifying Variational Autoencoders and … ) (一):https://blog.csdn.net/S20144144/article/details/99467235

你可能感兴趣的:(Pytorch,VAE,人脸生成)