PyTorch-GAN-master中的wgan的理解

PyTorch-GAN-master中wgan代码的解读

这是我第一次在csdn上写博客,因为前段时间开始接触到GAN,并且在一篇论文中看到wgan这样一个新奇的东西,因为从来没有写过GAN的网络,而且之前都是用pytorch在写网络,就在github上找到了PyTorch-GAN-master这个系列,里面恰好有一个wgan的网络,我在看过之后,就想自己能否尝试一下根据这个写一篇blog,当然,权当一篇笔记,供自己以后忘了的时候看看。处于当作笔记的目的,这里面写的主要是我作为一个新手想弄清楚的一些东西

PyTorch-GAN的链接

各个阶段的数据的shape

因为我也只是一个新手,在之前写网络的时候,其实最纠结的就是网络结构中的参数,和数据的shape,因为至少将这些弄清楚了,写对了,整个才能够运行,关于到Tensor的shape,主要就是网络的输入,forward,输出。

forward
输入
输出

(可能表达得不是很专业)
代码中的wgan算是一个最简单的wgan,因为就只有一个生成器和一个鉴别器,并且都是线性层,所用的数据集是mnist的手写数字数据集,也就是一张张28*28的图片,网络中默认的batch_size是64。
其中Generator网络部分代码如下:

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [ nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),#z.shape=([64, 100])
            *block(128, 256),
            *block(256, 512),#[64, 512]
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),#连乘函数(1*28*28)([64, 1*28*28])
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)#img.shape=([64, 1*28*28])
        img = img.view(img.shape[0], *img_shape)#shape[0]第一维的长度1([64, 1, 28, 28])
        return img

Discriminator的网络部分代码如下

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1)
        )

    def forward(self, img):
        img_flat = img.view(img.shape[0], -1)#先将img展平为[64, 1*28*28]
        validity = self.model(img_flat)#[64,1]
        return validity

在GAN训练的过程中,应该都知道,生成器通过一个噪声生成一个目标(本文中就是一张图片),这个目标尽可能去迷惑鉴别器,而鉴别器又尽可能得去判断真实的和生成的目标。
所在,本例中的输入噪声是z

z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))#(均值, 标准差, size)([64,100])

z就是一个[64, 100]的tensor,里面的数据的分布是0为均值,1为标准差的正态分布的随机数据,G的输入是z,输出generator(z),D的输入为generator(z),输出为D(gen(z)),各个阶段的数据形式我都在注释中写出来了。

训练的具体过程

网络训练部分的代码如下:

batches_done = 0
for epoch in range(opt.n_epochs):

    for i, (imgs, _) in enumerate(dataloader):

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))#imgs.shape=([64, 1, 28, 28])

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Sample noise as generator input

        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))#(均值, 标准差, size)([64,100])
        # Generate a batch of images
        fake_imgs = generator(z).detach()#不求G中的参数
        #generator(z).shape=([64, 1, 28, 28])
        # Adversarial loss
        loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))

        loss_D.backward()
        optimizer_D.step()

        # Clip weights of discriminator
        for p in discriminator.parameters():
            p.data.clamp_(-opt.clip_value, opt.clip_value)

        # Train the generator every n_critic iterations
        if i % opt.n_critic == 0:

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Generate a batch of images
            gen_imgs = generator(z)
            # Adversarial loss
            loss_G = -torch.mean(discriminator(gen_imgs))#discriminator(gen_imags).shape=([64, 1])
            loss_G.backward()
            optimizer_G.step()
            print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, opt.n_epochs,
                                                            batches_done % len(dataloader), len(dataloader),
                                                            loss_D.item(), loss_G.item()))

        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], 'images/%d.png' % batches_done, nrow=5, normalize=True)
        batches_done += 1

训练过程描述如下:
先初始D的梯度
利用噪声生成G的输入Z
G根据Z生成fake_imgs
loss_D=-E+ED(f)
D的梯度反向传播并优化
将D中的参数限制(-c, c)的范围内
每隔opt.n_critic进行优化G
 将G的梯度归零
 生成图像
 loss_G=-ED(G(z))
 G的梯度反向传播并优化网络
每隔一定的步数输出生成的图形

一些可以着重理解的点

1.WGAN的几个特征
生成器和判别器的loss不取log
每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c
不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行(本例中用的RMSProp)
关于WGAN的知识可以参考令人拍案叫绝的Wasserstein GAN

2.G的参数不需要反向传播。也就是

fake_imgs = generator(z).detach()

生成的结果

我截取了几张生成结果,如下:
第50000步生成的图片
PyTorch-GAN-master中的wgan的理解_第1张图片
第100000步生成的图片

PyTorch-GAN-master中的wgan的理解_第2张图片第150000步生成的图片
PyTorch-GAN-master中的wgan的理解_第3张图片

以上仅是个人看法,如果有错误的地方,望大家指正!

你可能感兴趣的:(PyTorch-GAN-master中的wgan的理解)