Table of Contents
Introduction
Generative Adversarial Networks
What is a GAN?
What is a DCGAN?
Inputs
Data
Implementation
Weight Initialization
Generator
Discriminator
Loss Functions and Optimizers
Training
Where to Go Next
Original tutorial: https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
This tutorial gives an introduction to DCGANs (Deep Convolutional Generative Adversarial Networks) through an example. We will train a generative adversarial network (GAN) to produce new celebrities after showing it pictures of many real celebrities. The implementation here comes from pytorch/examples; this document gives a thorough explanation of the code and sheds light on how and why this model works. But don't worry: no prior knowledge of GANs is required, though it may take some time to reason about what is actually happening under the hood. Also, for the sake of time, it helps to have a GPU, or two. Let's start from the beginning.
GANs are a framework for teaching a deep learning (DL) model to capture the distribution of the training data, so that we can generate new data from that same distribution. GANs were invented by Ian Goodfellow in 2014 and first described in the paper Generative Adversarial Nets. They consist of two distinct models: a generator and a discriminator. The job of the generator is to produce "fake" images that look like the training images; the job of the discriminator is to look at an image and decide whether it is a real training image or a fake produced by the generator. During training, the generator constantly tries to outsmart the discriminator by generating better and better fakes, while the discriminator works to become a better detective that correctly classifies real and fake images. The equilibrium of this game is reached when the generator produces fakes that look as if they came directly from the training data, and the discriminator is left to always guess, at 50% confidence, whether a generator output is real or fake.
Now, let's define some notation that will be used throughout the tutorial, starting with the discriminator. Let x be image data. D(x) is the discriminator network, which outputs the (scalar) probability that x came from the training data rather than the generator. Here, we are working with CHW (channel, height, width) images of size 3x64x64. Intuitively, D(x) should be high when x comes from the training data and low when x comes from the generator. D(x) can also be thought of as a traditional binary classifier.
For the generator's notation, let z be a latent vector sampled from a standard normal distribution. ("Latent" here carries no deep or obscure meaning; as with the hidden layers of a feed-forward network, it simply denotes a variable or space without direct physical meaning, and is generally not interpretable.) G(z) represents the generator function that maps the latent vector z to data space. The goal of G is to estimate the distribution that the training data comes from (p_data) so that it can generate fake samples from that estimated distribution (p_g).
So, D(G(z)) is the (scalar) probability that the output of the generator G is a real image. As described in Goodfellow's paper, D and G play a minimax game in which D tries to maximize the probability that it correctly classifies reals and fakes (log D(x)), and G tries to minimize the probability that D predicts its outputs are fake (log(1 − D(G(z)))). From the paper, the GAN loss function is:

min_G max_D V(D, G) = E_{x∼p_data(x)}[log D(x)] + E_{z∼p_z(z)}[log(1 − D(G(z)))]
In theory, the solution to this minimax game is p_g = p_data, at which point the discriminator guesses randomly whether its inputs are real or fake. However, the convergence theory of GANs is still being actively researched, and in reality models do not always train to this point.
A DCGAN is a direct extension of the GAN described above, except that it explicitly uses convolutional and convolutional-transpose layers in the discriminator and generator, respectively. It was first described by Radford et al. in the paper Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks. The discriminator is made up of strided convolution layers, batch norm layers, and LeakyReLU activations. The input is a 3x64x64 image, and the output is a scalar probability that the input came from the real data distribution. The generator is made up of convolutional-transpose layers, batch norm layers, and ReLU activations. The input is a latent vector z drawn from a standard normal distribution, and the output is a 3x64x64 RGB image. The conv-transpose layers allow the latent vector to be transformed into a volume with the same shape as an image. In the paper, the authors also give some tips on how to set up the optimizers, how to compute the loss functions, and how to initialize the model weights, all of which will be explained in the coming sections.
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
Output:
Random Seed: 999
Let's define some inputs for the run:
# Root directory for dataset
dataroot = "data/celeba"
# Number of workers for dataloader
workers = 2
# Batch size during training
batch_size = 128
# Spatial size of training images. All images will be resized to this
# size using a transformer.
image_size = 64
# Number of channels in the training images. For color images this is 3
nc = 3
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Size of feature maps in generator
ngf = 64
# Size of feature maps in discriminator
ndf = 64
# Number of training epochs
num_epochs = 5
# Learning rate for optimizers
lr = 0.0002
# Beta1 hyperparam for Adam optimizers
beta1 = 0.5
# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1
In this tutorial we will use the Celeb-A Faces dataset, which can be downloaded from the linked site or from Google Drive. The dataset downloads as a file named img_align_celeba.zip. Once downloaded, create a directory named celeba and extract the zip file into that directory. Then, set the dataroot input from the previous section to the celeba directory you just created. The resulting directory structure should be:
/path/to/celeba
-> img_align_celeba
-> 188242.jpg
-> 173822.jpg
-> 284702.jpg
-> 537394.jpg
...
This is an important step because we will be using the ImageFolder dataset class, which requires there to be subdirectories in the dataset's root folder. Now we can create the dataset, create the dataloader, set the device to run on, and finally visualize some of the training data.
# We can use an image folder dataset the way we have it setup.
# Create the dataset
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)
# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
With our input parameters set and the dataset prepared, we can now get into the implementation. We will start with the weight initialization strategy, then talk about the generator, discriminator, loss functions, and training loop in detail.
From the DCGAN paper, the authors specify that all model weights shall be randomly initialized from a normal distribution with mean 0 and standard deviation 0.02. The weights_init function takes an initialized model as input and reinitializes all convolutional, convolutional-transpose, and batch normalization layers to meet this criterion. This function is applied to a model immediately after initialization. (In this article, "weights" and "parameters" are used interchangeably.)
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
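As a quick sanity check (a standalone sketch, not part of the tutorial code), we can apply this initializer to a single convolutional layer and confirm that the resulting weight statistics match the mean 0, standard deviation 0.02 target:

```python
import torch
import torch.nn as nn

# Same initializer as in the tutorial
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

torch.manual_seed(0)
# A fairly large layer, so sample statistics land close to the target values
conv = nn.Conv2d(64, 128, kernel_size=4)
conv.apply(weights_init)

w = conv.weight.data
print(round(w.mean().item(), 3))  # close to 0.0
print(round(w.std().item(), 3))   # close to 0.02
```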
The generator, G, is designed to map the latent vector z to data space. Since our data are images, this means producing an RGB image with the same size as the training images (i.e. 3x64x64) from z. In practice, this is accomplished through a series of ConvTranspose2d, BatchNorm2d, and ReLU layers. The output of the generator is fed through a tanh activation to map the data back to the range [−1, 1]. It is worth noting the batch norm layers placed after the conv-transpose layers, a key contribution of the DCGAN paper. These layers help with the flow of gradients during training. The generator from the DCGAN paper is shown in the figure below.
Notice how the inputs we set in the Inputs section (nz, ngf, and nc) influence the generator architecture. nz is the length of the latent vector z, ngf relates to the size of the feature maps propagated through the generator, and nc is the number of channels in the output image (set to 3 for RGB images). The generator code is below:
# Generator Code
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(in_channels=nz, out_channels=ngf * 8, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(in_channels=ngf * 8, out_channels=ngf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(in_channels=ngf * 4, out_channels=ngf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(in_channels=ngf * 2, out_channels=ngf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(in_channels=ngf, out_channels=nc, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    """
    A conv-transpose layer can be thought of as inverting the shape mapping of a convolution.
    Take the last conv-transpose layer as an example: a convolution with kernel_size=4,
    stride=2, padding=1 applied to a (nc) x 64 x 64 input would produce an output of
    spatial size Hout = (Hin + 2*padding - kernel_size)/stride + 1 = (64 + 2*1 - 4)/2 + 1 = 32.
    The conv-transpose layer reverses that input/output relationship: its input is
    (ngf) x 32 x 32 and its output is (nc) x 64 x 64.
    """

    def forward(self, input):
        return self.main(input)
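To see how the conv-transpose stack grows the 1x1 latent vector into a 64x64 image, we can trace the shapes layer by layer. The sketch below is standalone and uses the same hyperparameters as the Inputs section (nz=100, ngf=64, nc=3); the BatchNorm2d and ReLU layers are omitted since they do not change spatial dimensions:

```python
import torch
import torch.nn as nn

nz, ngf, nc = 100, 64, 3  # same values as the Inputs section

# The generator's conv-transpose stack only
layers = nn.Sequential(
    nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),       # 1x1   -> 4x4
    nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),  # 4x4   -> 8x8
    nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),  # 8x8   -> 16x16
    nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),      # 16x16 -> 32x32
    nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),           # 32x32 -> 64x64
)

x = torch.randn(1, nz, 1, 1)  # a single latent vector
for layer in layers:
    x = layer(x)
    print(tuple(x.shape))
# final shape: (1, 3, 64, 64)
```

Each step follows the conv-transpose size formula Hout = (Hin − 1)·stride − 2·padding + kernel_size.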
Now we can instantiate the generator and apply the weights_init function. Print the model to see how the generator object is structured.
# Create the generator
netG = Generator(ngpu).to(device)
# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
netG = nn.DataParallel(netG, list(range(ngpu)))
# Apply the weights_init function to randomly initialize all weights
# to mean=0, stdev=0.02.
netG.apply(weights_init)
# Print the model
print(netG)
Output:
Generator(
(main): Sequential(
(0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
(9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(11): ReLU(inplace=True)
(12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(13): Tanh()
)
)
As mentioned, the discriminator, D, is a binary classification network that takes an image as input and outputs a scalar probability that the input image is real. Here, D takes a 3x64x64 input image, processes it through a series of Conv2d, BatchNorm2d, and LeakyReLU layers, and outputs the final probability through a Sigmoid activation. This architecture can be extended with more layers if necessary. The DCGAN paper mentions that using strided convolution rather than pooling to downsample is good practice because it lets the network learn its own pooling function. The batch norm and LeakyReLU layers also promote healthy gradient flow, which is critical for the learning of both G and D.
Discriminator code:
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)
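Symmetrically, the discriminator's strided convolutions shrink a 64x64 image down to a single per-image score. A standalone shape check (assuming nc=3 and ndf=64 as above; the BatchNorm2d and LeakyReLU layers are omitted since they leave shapes unchanged):

```python
import torch
import torch.nn as nn

nc, ndf = 3, 64

# The discriminator's strided-conv stack only
layers = nn.Sequential(
    nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),           # 64x64 -> 32x32
    nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),      # 32x32 -> 16x16
    nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),  # 16x16 -> 8x8
    nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),  # 8x8   -> 4x4
    nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),        # 4x4   -> 1x1
    nn.Sigmoid(),
)

img = torch.randn(2, nc, 64, 64)  # a batch of two random "images"
out = layers(img).view(-1)        # flatten to one score per image
print(tuple(out.shape))           # (2,)
print(bool(((out >= 0) & (out <= 1)).all()))  # Sigmoid keeps scores in [0, 1]
```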
Now, as we did with the generator, we can instantiate the discriminator, apply weights_init, and print the model's structure.
# Create the Discriminator
netD = Discriminator(ngpu).to(device)
# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
netD = nn.DataParallel(netD, list(range(ngpu)))
# Apply the weights_init function to randomly initialize all weights
# to mean=0, stdev=0.02.
netD.apply(weights_init)
# Print the model
print(netD)
Output:
Discriminator(
(main): Sequential(
(0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(10): LeakyReLU(negative_slope=0.2, inplace=True)
(11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
(12): Sigmoid()
)
)
With the generator G and discriminator D set up, we can specify how they learn through the loss functions and optimizers. We will use the Binary Cross Entropy loss (BCELoss), which is defined in PyTorch as:

ℓ(x, y) = L = {l_1, ..., l_N},  where  l_n = −[y_n · log(x_n) + (1 − y_n) · log(1 − x_n)]
Notice how this function provides both log components of the objective (i.e. log(D(x)) and log(1 − D(G(z)))). We can specify which part of the BCE equation to use through the input labels y. This is done in the training loop that follows, but it is important to understand that we can choose which log component to compute simply by changing y (i.e. the ground-truth, or GT, labels).
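This label trick can be verified numerically: with targets of 1, BCELoss reduces to −log(x), and with targets of 0 it reduces to −log(1 − x). A standalone sketch:

```python
import math
import torch
import torch.nn as nn

criterion = nn.BCELoss()
x = torch.tensor([0.9])  # a discriminator "probability"

# y = 1 selects the -log(x) term
loss_real = criterion(x, torch.tensor([1.0]))
# y = 0 selects the -log(1 - x) term
loss_fake = criterion(x, torch.tensor([0.0]))

print(round(loss_real.item(), 4))  # -log(0.9) ~ 0.1054
print(round(loss_fake.item(), 4))  # -log(0.1) ~ 2.3026
```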
Next, we define the real label as 1 and the fake label as 0. These labels will be used when calculating the losses of D and G, which is also the convention used in the original GAN paper. Finally, we set up two separate optimizers, one for G and one for D. As specified in the DCGAN paper, both are Adam optimizers with learning rate 0.0002 and Beta1 = 0.5. To keep track of the generator's learning progress, we will generate a fixed batch of latent vectors drawn from a Gaussian distribution (fixed_noise). In the training loop, we will periodically feed this fixed noise into G, and over the iterations we will see images form out of the noise.
# Initialize BCELoss function
criterion = nn.BCELoss()
# Create batch of latent vectors that we will use to visualize
# the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)
# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.
# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
Finally, now that we have all of the parts of the GAN framework defined, we can train it. Be mindful that training GANs is somewhat of an art form, as incorrect hyperparameter settings lead to mode collapse with little explanation of what went wrong. Here, we will closely follow Algorithm 1 from Goodfellow's paper, while abiding by some of the best practices shown in ganhacks. Namely, we will "construct different mini-batches for real and fake" images, and also adjust G's objective function to maximize log(D(G(z))). Training is split into two main parts: Part 1 updates the discriminator and Part 2 updates the generator.
Part 1 - Train the Discriminator
Recall that the goal of training the discriminator is to maximize the probability of correctly classifying a given input as real or fake. In Goodfellow's words, we wish to "update the discriminator by ascending its stochastic gradient". Practically, we want to maximize log(D(x)) + log(1 − D(G(z))). Following the separate mini-batch suggestion from ganhacks, we calculate this in two steps. First, we construct a batch of real samples from the training set, forward it through D, calculate the loss (log(D(x))), then calculate the gradients in a backward pass. Second, we construct a batch of fake samples with the current generator, forward this batch through D, calculate the loss (log(1 − D(G(z)))), and accumulate the gradients with another backward pass. Finally, with the gradients accumulated from both the all-real and all-fake batches, we call a step of the discriminator's optimizer to update D's parameters.
Part 2 - Train the Generator
As stated in the original paper, we want to train the generator by minimizing log(1 − D(G(z))) in an effort to generate better fakes. As a workaround, we instead maximize log(D(G(z))). This is accomplished by: classifying the generator output from Part 1 with the discriminator, computing G's loss using the real labels as ground truth (GT), computing G's gradients in a backward pass, and finally updating G's parameters with an optimizer step. It may seem counter-intuitive to use the real labels as GT for the loss function, but this lets us use the log(x) part of BCELoss rather than the log(1 − x) part, which is exactly what we want.
Finally, we will do some statistic reporting, and at the end of each epoch we will push our fixed_noise batch through the generator to visually track the progress of G's training. The training statistics reported are:

- Loss_D - discriminator loss, calculated as the sum of losses for the all-real and all-fake batches (log(D(x)) + log(1 − D(G(z)))).
- Loss_G - generator loss, calculated as log(D(G(z))).
- D(x) - the average output (across the batch) of the discriminator for the all-real batch. This should start close to 1 and theoretically converge to 0.5 as G gets better.
- D(G(z)) - the average discriminator output for the all-fake batch. The first number is before D is updated and the second is after. These should start near 0 and converge to 0.5 as G gets better.
Note: this step may take a while, depending on how many epochs you run and the size of your dataset.
# Training Loop
# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0
print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1
Output:
Starting Training Loop...
[0/5][0/1583] Loss_D: 1.7834 Loss_G: 5.0952 D(x): 0.5564 D(G(z)): 0.5963 / 0.0094
[0/5][50/1583] Loss_D: 0.2582 Loss_G: 28.5604 D(x): 0.8865 D(G(z)): 0.0000 / 0.0000
[0/5][100/1583] Loss_D: 0.9311 Loss_G: 13.3240 D(x): 0.9443 D(G(z)): 0.4966 / 0.0000
[0/5][150/1583] Loss_D: 0.7385 Loss_G: 8.8132 D(x): 0.9581 D(G(z)): 0.4625 / 0.0004
[0/5][200/1583] Loss_D: 0.4796 Loss_G: 6.5862 D(x): 0.9888 D(G(z)): 0.3271 / 0.0047
[0/5][250/1583] Loss_D: 0.7410 Loss_G: 5.4159 D(x): 0.8282 D(G(z)): 0.3274 / 0.0082
[0/5][300/1583] Loss_D: 0.4622 Loss_G: 3.7107 D(x): 0.7776 D(G(z)): 0.1251 / 0.0375
[0/5][350/1583] Loss_D: 1.0642 Loss_G: 6.3149 D(x): 0.9374 D(G(z)): 0.5391 / 0.0061
[0/5][400/1583] Loss_D: 0.3848 Loss_G: 6.3376 D(x): 0.9153 D(G(z)): 0.2209 / 0.0036
[0/5][450/1583] Loss_D: 0.2790 Loss_G: 4.3376 D(x): 0.8896 D(G(z)): 0.1256 / 0.0217
[0/5][500/1583] Loss_D: 1.2478 Loss_G: 8.1121 D(x): 0.9361 D(G(z)): 0.5578 / 0.0016
[0/5][550/1583] Loss_D: 0.3393 Loss_G: 4.0673 D(x): 0.8257 D(G(z)): 0.0496 / 0.0323
[0/5][600/1583] Loss_D: 0.8083 Loss_G: 2.5396 D(x): 0.6232 D(G(z)): 0.0484 / 0.1265
[0/5][650/1583] Loss_D: 0.3682 Loss_G: 4.3142 D(x): 0.8227 D(G(z)): 0.1114 / 0.0217
[0/5][700/1583] Loss_D: 0.4788 Loss_G: 6.2379 D(x): 0.8594 D(G(z)): 0.2307 / 0.0037
[0/5][750/1583] Loss_D: 0.4767 Loss_G: 5.3962 D(x): 0.8935 D(G(z)): 0.2463 / 0.0092
[0/5][800/1583] Loss_D: 0.8085 Loss_G: 2.3573 D(x): 0.5934 D(G(z)): 0.0769 / 0.1357
[0/5][850/1583] Loss_D: 0.3595 Loss_G: 3.9025 D(x): 0.7769 D(G(z)): 0.0563 / 0.0381
[0/5][900/1583] Loss_D: 0.3235 Loss_G: 4.7795 D(x): 0.9224 D(G(z)): 0.1785 / 0.0163
[0/5][950/1583] Loss_D: 0.3426 Loss_G: 3.1228 D(x): 0.8257 D(G(z)): 0.0847 / 0.0795
[0/5][1000/1583] Loss_D: 0.6667 Loss_G: 7.3167 D(x): 0.9556 D(G(z)): 0.3751 / 0.0019
[0/5][1050/1583] Loss_D: 0.2840 Loss_G: 5.0387 D(x): 0.9268 D(G(z)): 0.1642 / 0.0143
[0/5][1100/1583] Loss_D: 0.4534 Loss_G: 3.8780 D(x): 0.7535 D(G(z)): 0.0697 / 0.0391
[0/5][1150/1583] Loss_D: 0.5040 Loss_G: 2.9283 D(x): 0.7452 D(G(z)): 0.1167 / 0.0839
[0/5][1200/1583] Loss_D: 0.6478 Loss_G: 4.0913 D(x): 0.6595 D(G(z)): 0.0263 / 0.0358
[0/5][1250/1583] Loss_D: 1.2299 Loss_G: 7.8236 D(x): 0.9850 D(G(z)): 0.5941 / 0.0013
[0/5][1300/1583] Loss_D: 0.3228 Loss_G: 4.9211 D(x): 0.8882 D(G(z)): 0.1488 / 0.0140
[0/5][1350/1583] Loss_D: 0.4208 Loss_G: 4.1520 D(x): 0.8254 D(G(z)): 0.1638 / 0.0260
[0/5][1400/1583] Loss_D: 0.5751 Loss_G: 3.9585 D(x): 0.7692 D(G(z)): 0.1902 / 0.0329
[0/5][1450/1583] Loss_D: 1.6244 Loss_G: 0.5350 D(x): 0.3037 D(G(z)): 0.0159 / 0.6617
[0/5][1500/1583] Loss_D: 0.3676 Loss_G: 3.2653 D(x): 0.8076 D(G(z)): 0.0825 / 0.0710
[0/5][1550/1583] Loss_D: 0.2759 Loss_G: 4.4156 D(x): 0.9010 D(G(z)): 0.1370 / 0.0178
[1/5][0/1583] Loss_D: 1.0879 Loss_G: 7.8641 D(x): 0.8737 D(G(z)): 0.5376 / 0.0008
[1/5][50/1583] Loss_D: 0.2761 Loss_G: 4.4716 D(x): 0.9008 D(G(z)): 0.1267 / 0.0231
[1/5][100/1583] Loss_D: 0.3438 Loss_G: 4.0343 D(x): 0.8389 D(G(z)): 0.1162 / 0.0308
[1/5][150/1583] Loss_D: 0.4937 Loss_G: 4.8593 D(x): 0.7951 D(G(z)): 0.1819 / 0.0162
[1/5][200/1583] Loss_D: 0.3973 Loss_G: 3.2078 D(x): 0.8671 D(G(z)): 0.1916 / 0.0587
[1/5][250/1583] Loss_D: 0.4521 Loss_G: 4.5155 D(x): 0.9006 D(G(z)): 0.2441 / 0.0222
[1/5][300/1583] Loss_D: 0.4423 Loss_G: 5.3907 D(x): 0.8635 D(G(z)): 0.2039 / 0.0125
[1/5][350/1583] Loss_D: 0.6447 Loss_G: 2.5607 D(x): 0.6177 D(G(z)): 0.0195 / 0.1284
[1/5][400/1583] Loss_D: 0.4079 Loss_G: 4.2563 D(x): 0.8621 D(G(z)): 0.1949 / 0.0268
[1/5][450/1583] Loss_D: 0.9649 Loss_G: 8.0302 D(x): 0.9727 D(G(z)): 0.5302 / 0.0010
[1/5][500/1583] Loss_D: 0.7693 Loss_G: 5.9895 D(x): 0.9070 D(G(z)): 0.4331 / 0.0053
[1/5][550/1583] Loss_D: 0.4522 Loss_G: 2.6169 D(x): 0.7328 D(G(z)): 0.0634 / 0.1113
[1/5][600/1583] Loss_D: 0.4039 Loss_G: 3.4861 D(x): 0.8436 D(G(z)): 0.1738 / 0.0494
[1/5][650/1583] Loss_D: 0.4434 Loss_G: 3.0261 D(x): 0.7756 D(G(z)): 0.1299 / 0.0777
[1/5][700/1583] Loss_D: 1.5401 Loss_G: 8.3636 D(x): 0.9705 D(G(z)): 0.7050 / 0.0011
[1/5][750/1583] Loss_D: 0.3899 Loss_G: 4.3379 D(x): 0.7379 D(G(z)): 0.0231 / 0.0248
[1/5][800/1583] Loss_D: 0.9547 Loss_G: 5.6122 D(x): 0.9520 D(G(z)): 0.5318 / 0.0074
[1/5][850/1583] Loss_D: 0.3714 Loss_G: 3.2116 D(x): 0.7770 D(G(z)): 0.0752 / 0.0700
[1/5][900/1583] Loss_D: 0.2717 Loss_G: 4.0063 D(x): 0.8673 D(G(z)): 0.1058 / 0.0272
[1/5][950/1583] Loss_D: 0.2652 Loss_G: 3.7649 D(x): 0.8381 D(G(z)): 0.0540 / 0.0361
[1/5][1000/1583] Loss_D: 0.9463 Loss_G: 1.6266 D(x): 0.5189 D(G(z)): 0.0913 / 0.2722
[1/5][1050/1583] Loss_D: 0.7117 Loss_G: 3.7363 D(x): 0.8544 D(G(z)): 0.3578 / 0.0397
[1/5][1100/1583] Loss_D: 0.5164 Loss_G: 4.0939 D(x): 0.8865 D(G(z)): 0.2904 / 0.0252
[1/5][1150/1583] Loss_D: 0.3745 Loss_G: 3.1891 D(x): 0.8262 D(G(z)): 0.1358 / 0.0645
[1/5][1200/1583] Loss_D: 0.4583 Loss_G: 2.9545 D(x): 0.7866 D(G(z)): 0.1453 / 0.0778
[1/5][1250/1583] Loss_D: 0.5870 Loss_G: 4.4096 D(x): 0.9473 D(G(z)): 0.3706 / 0.0208
[1/5][1300/1583] Loss_D: 0.5159 Loss_G: 4.1076 D(x): 0.8640 D(G(z)): 0.2738 / 0.0240
[1/5][1350/1583] Loss_D: 0.6005 Loss_G: 1.8590 D(x): 0.6283 D(G(z)): 0.0418 / 0.2032
[1/5][1400/1583] Loss_D: 0.3646 Loss_G: 3.4323 D(x): 0.7712 D(G(z)): 0.0653 / 0.0534
[1/5][1450/1583] Loss_D: 0.6245 Loss_G: 2.2462 D(x): 0.6515 D(G(z)): 0.0905 / 0.1514
[1/5][1500/1583] Loss_D: 0.6055 Loss_G: 1.7674 D(x): 0.7026 D(G(z)): 0.1682 / 0.2169
[1/5][1550/1583] Loss_D: 0.5181 Loss_G: 3.2728 D(x): 0.7926 D(G(z)): 0.2048 / 0.0549
[2/5][0/1583] Loss_D: 0.9580 Loss_G: 5.1154 D(x): 0.9605 D(G(z)): 0.5535 / 0.0105
[2/5][50/1583] Loss_D: 0.9947 Loss_G: 1.7223 D(x): 0.4860 D(G(z)): 0.0563 / 0.2477
[2/5][100/1583] Loss_D: 0.7023 Loss_G: 4.1781 D(x): 0.9083 D(G(z)): 0.4116 / 0.0239
[2/5][150/1583] Loss_D: 0.3496 Loss_G: 2.7264 D(x): 0.8871 D(G(z)): 0.1795 / 0.0982
[2/5][200/1583] Loss_D: 0.6805 Loss_G: 3.8157 D(x): 0.8900 D(G(z)): 0.3851 / 0.0312
[2/5][250/1583] Loss_D: 0.6193 Loss_G: 3.8180 D(x): 0.8557 D(G(z)): 0.3286 / 0.0303
[2/5][300/1583] Loss_D: 0.6480 Loss_G: 1.4683 D(x): 0.6157 D(G(z)): 0.0640 / 0.2844
[2/5][350/1583] Loss_D: 0.7498 Loss_G: 4.1299 D(x): 0.8922 D(G(z)): 0.4244 / 0.0256
[2/5][400/1583] Loss_D: 0.7603 Loss_G: 4.2291 D(x): 0.9512 D(G(z)): 0.4604 / 0.0213
[2/5][450/1583] Loss_D: 0.4833 Loss_G: 4.0068 D(x): 0.9348 D(G(z)): 0.3095 / 0.0257
[2/5][500/1583] Loss_D: 1.2311 Loss_G: 0.7107 D(x): 0.3949 D(G(z)): 0.0496 / 0.5440
[2/5][550/1583] Loss_D: 0.9657 Loss_G: 1.5119 D(x): 0.4513 D(G(z)): 0.0338 / 0.2821
[2/5][600/1583] Loss_D: 0.5351 Loss_G: 3.4546 D(x): 0.8889 D(G(z)): 0.3018 / 0.0449
[2/5][650/1583] Loss_D: 0.8761 Loss_G: 1.2051 D(x): 0.5292 D(G(z)): 0.1193 / 0.3583
[2/5][700/1583] Loss_D: 1.0206 Loss_G: 4.5741 D(x): 0.8599 D(G(z)): 0.5140 / 0.0159
[2/5][750/1583] Loss_D: 1.0886 Loss_G: 5.4749 D(x): 0.9770 D(G(z)): 0.6093 / 0.0067
[2/5][800/1583] Loss_D: 0.6539 Loss_G: 3.5203 D(x): 0.9074 D(G(z)): 0.3962 / 0.0390
[2/5][850/1583] Loss_D: 0.8633 Loss_G: 1.0995 D(x): 0.5701 D(G(z)): 0.1401 / 0.3842
[2/5][900/1583] Loss_D: 0.3703 Loss_G: 2.2482 D(x): 0.8183 D(G(z)): 0.1302 / 0.1329
[2/5][950/1583] Loss_D: 0.6592 Loss_G: 1.6081 D(x): 0.6040 D(G(z)): 0.0818 / 0.2523
[2/5][1000/1583] Loss_D: 0.7449 Loss_G: 1.0548 D(x): 0.5975 D(G(z)): 0.1375 / 0.4085
[2/5][1050/1583] Loss_D: 0.5783 Loss_G: 2.3644 D(x): 0.6435 D(G(z)): 0.0531 / 0.1357
[2/5][1100/1583] Loss_D: 0.6123 Loss_G: 2.2695 D(x): 0.7269 D(G(z)): 0.2083 / 0.1343
[2/5][1150/1583] Loss_D: 0.6263 Loss_G: 1.8714 D(x): 0.6661 D(G(z)): 0.1407 / 0.1914
[2/5][1200/1583] Loss_D: 0.4233 Loss_G: 3.0119 D(x): 0.8533 D(G(z)): 0.2039 / 0.0692
[2/5][1250/1583] Loss_D: 0.8826 Loss_G: 3.3618 D(x): 0.7851 D(G(z)): 0.3971 / 0.0502
[2/5][1300/1583] Loss_D: 0.6201 Loss_G: 2.1584 D(x): 0.6418 D(G(z)): 0.0977 / 0.1536
[2/5][1350/1583] Loss_D: 0.9558 Loss_G: 3.8876 D(x): 0.8561 D(G(z)): 0.5001 / 0.0302
[2/5][1400/1583] Loss_D: 0.4369 Loss_G: 2.3479 D(x): 0.7959 D(G(z)): 0.1588 / 0.1214
[2/5][1450/1583] Loss_D: 0.5086 Loss_G: 2.1034 D(x): 0.6758 D(G(z)): 0.0586 / 0.1575
[2/5][1500/1583] Loss_D: 0.6513 Loss_G: 3.5801 D(x): 0.8535 D(G(z)): 0.3429 / 0.0455
[2/5][1550/1583] Loss_D: 0.6975 Loss_G: 2.5560 D(x): 0.7379 D(G(z)): 0.2784 / 0.1031
[3/5][0/1583] Loss_D: 2.2846 Loss_G: 1.7977 D(x): 0.1771 D(G(z)): 0.0111 / 0.2394
[3/5][50/1583] Loss_D: 1.6111 Loss_G: 5.7904 D(x): 0.9581 D(G(z)): 0.7350 / 0.0063
[3/5][100/1583] Loss_D: 0.8553 Loss_G: 1.0540 D(x): 0.5229 D(G(z)): 0.1020 / 0.3945
[3/5][150/1583] Loss_D: 0.7402 Loss_G: 2.6338 D(x): 0.7668 D(G(z)): 0.3277 / 0.0959
[3/5][200/1583] Loss_D: 0.9278 Loss_G: 2.9689 D(x): 0.8913 D(G(z)): 0.4787 / 0.0769
[3/5][250/1583] Loss_D: 2.6573 Loss_G: 6.4810 D(x): 0.9684 D(G(z)): 0.8799 / 0.0035
[3/5][300/1583] Loss_D: 0.5435 Loss_G: 1.9416 D(x): 0.7118 D(G(z)): 0.1454 / 0.1801
[3/5][350/1583] Loss_D: 1.2350 Loss_G: 4.6877 D(x): 0.9595 D(G(z)): 0.6444 / 0.0147
[3/5][400/1583] Loss_D: 0.9264 Loss_G: 0.9139 D(x): 0.4825 D(G(z)): 0.0715 / 0.4526
[3/5][450/1583] Loss_D: 0.8967 Loss_G: 4.4258 D(x): 0.9155 D(G(z)): 0.5074 / 0.0174
[3/5][500/1583] Loss_D: 0.6874 Loss_G: 2.4529 D(x): 0.7775 D(G(z)): 0.3171 / 0.1097
[3/5][550/1583] Loss_D: 0.5821 Loss_G: 3.0756 D(x): 0.8681 D(G(z)): 0.3161 / 0.0609
[3/5][600/1583] Loss_D: 0.7164 Loss_G: 1.5045 D(x): 0.5652 D(G(z)): 0.0428 / 0.2868
[3/5][650/1583] Loss_D: 0.6290 Loss_G: 2.1863 D(x): 0.7952 D(G(z)): 0.2829 / 0.1442
[3/5][700/1583] Loss_D: 0.6270 Loss_G: 1.2824 D(x): 0.6481 D(G(z)): 0.1184 / 0.3234
[3/5][750/1583] Loss_D: 0.7011 Loss_G: 1.3549 D(x): 0.5861 D(G(z)): 0.0926 / 0.3017
[3/5][800/1583] Loss_D: 0.6912 Loss_G: 1.4927 D(x): 0.5919 D(G(z)): 0.0741 / 0.2728
[3/5][850/1583] Loss_D: 0.6385 Loss_G: 2.9333 D(x): 0.8418 D(G(z)): 0.3338 / 0.0723
[3/5][900/1583] Loss_D: 0.7835 Loss_G: 4.4475 D(x): 0.9290 D(G(z)): 0.4703 / 0.0151
[3/5][950/1583] Loss_D: 0.6294 Loss_G: 2.3463 D(x): 0.7388 D(G(z)): 0.2414 / 0.1202
[3/5][1000/1583] Loss_D: 0.6288 Loss_G: 1.5448 D(x): 0.6575 D(G(z)): 0.1389 / 0.2581
[3/5][1050/1583] Loss_D: 0.6292 Loss_G: 3.4867 D(x): 0.8741 D(G(z)): 0.3549 / 0.0433
[3/5][1100/1583] Loss_D: 0.7644 Loss_G: 1.7661 D(x): 0.5457 D(G(z)): 0.0408 / 0.2076
[3/5][1150/1583] Loss_D: 0.4918 Loss_G: 3.1858 D(x): 0.8576 D(G(z)): 0.2563 / 0.0527
[3/5][1200/1583] Loss_D: 1.1773 Loss_G: 4.5200 D(x): 0.8192 D(G(z)): 0.5536 / 0.0183
[3/5][1250/1583] Loss_D: 0.6889 Loss_G: 1.8073 D(x): 0.6909 D(G(z)): 0.2230 / 0.1969
[3/5][1300/1583] Loss_D: 0.9721 Loss_G: 1.0578 D(x): 0.4541 D(G(z)): 0.0570 / 0.4080
[3/5][1350/1583] Loss_D: 0.5301 Loss_G: 2.3562 D(x): 0.7453 D(G(z)): 0.1670 / 0.1222
[3/5][1400/1583] Loss_D: 0.5464 Loss_G: 2.5304 D(x): 0.8018 D(G(z)): 0.2438 / 0.1020
[3/5][1450/1583] Loss_D: 0.5987 Loss_G: 2.2034 D(x): 0.6195 D(G(z)): 0.0601 / 0.1477
[3/5][1500/1583] Loss_D: 1.4470 Loss_G: 4.2791 D(x): 0.9006 D(G(z)): 0.6537 / 0.0221
[3/5][1550/1583] Loss_D: 0.7917 Loss_G: 3.3235 D(x): 0.8287 D(G(z)): 0.4002 / 0.0489
[4/5][0/1583] Loss_D: 0.7682 Loss_G: 1.2445 D(x): 0.5371 D(G(z)): 0.0538 / 0.3386
[4/5][50/1583] Loss_D: 0.9274 Loss_G: 0.9439 D(x): 0.4905 D(G(z)): 0.1004 / 0.4476
[4/5][100/1583] Loss_D: 0.9571 Loss_G: 0.7391 D(x): 0.4619 D(G(z)): 0.0511 / 0.5431
[4/5][150/1583] Loss_D: 1.4795 Loss_G: 0.7522 D(x): 0.3092 D(G(z)): 0.0387 / 0.5307
[4/5][200/1583] Loss_D: 0.5203 Loss_G: 1.8662 D(x): 0.7279 D(G(z)): 0.1425 / 0.1895
[4/5][250/1583] Loss_D: 0.8140 Loss_G: 1.9120 D(x): 0.5155 D(G(z)): 0.0606 / 0.1939
[4/5][300/1583] Loss_D: 0.5813 Loss_G: 2.5807 D(x): 0.7674 D(G(z)): 0.2255 / 0.1008
[4/5][350/1583] Loss_D: 0.5209 Loss_G: 2.8571 D(x): 0.8125 D(G(z)): 0.2389 / 0.0743
[4/5][400/1583] Loss_D: 0.4505 Loss_G: 2.7965 D(x): 0.8221 D(G(z)): 0.2014 / 0.0805
[4/5][450/1583] Loss_D: 0.4919 Loss_G: 2.4360 D(x): 0.8148 D(G(z)): 0.2163 / 0.1100
[4/5][500/1583] Loss_D: 0.5861 Loss_G: 1.8476 D(x): 0.7139 D(G(z)): 0.1733 / 0.1968
[4/5][550/1583] Loss_D: 0.3823 Loss_G: 2.7134 D(x): 0.8286 D(G(z)): 0.1591 / 0.0833
[4/5][600/1583] Loss_D: 0.8388 Loss_G: 4.0517 D(x): 0.9135 D(G(z)): 0.4704 / 0.0238
[4/5][650/1583] Loss_D: 1.1851 Loss_G: 3.8484 D(x): 0.9364 D(G(z)): 0.6310 / 0.0301
[4/5][700/1583] Loss_D: 0.6797 Loss_G: 1.6355 D(x): 0.6011 D(G(z)): 0.0880 / 0.2444
[4/5][750/1583] Loss_D: 0.6017 Loss_G: 1.8937 D(x): 0.7011 D(G(z)): 0.1684 / 0.1909
[4/5][800/1583] Loss_D: 0.6368 Loss_G: 1.7310 D(x): 0.6652 D(G(z)): 0.1495 / 0.2195
[4/5][850/1583] Loss_D: 0.7758 Loss_G: 0.8409 D(x): 0.5400 D(G(z)): 0.0775 / 0.4691
[4/5][900/1583] Loss_D: 0.5234 Loss_G: 1.7439 D(x): 0.6728 D(G(z)): 0.0839 / 0.2216
[4/5][950/1583] Loss_D: 0.6529 Loss_G: 3.4036 D(x): 0.9078 D(G(z)): 0.3899 / 0.0443
[4/5][1000/1583] Loss_D: 0.6068 Loss_G: 2.1435 D(x): 0.7773 D(G(z)): 0.2603 / 0.1434
[4/5][1050/1583] Loss_D: 0.9208 Loss_G: 2.4387 D(x): 0.7600 D(G(z)): 0.4164 / 0.1163
[4/5][1100/1583] Loss_D: 0.6253 Loss_G: 1.8932 D(x): 0.6321 D(G(z)): 0.0981 / 0.1835
[4/5][1150/1583] Loss_D: 0.6524 Loss_G: 2.7757 D(x): 0.7961 D(G(z)): 0.2996 / 0.0823
[4/5][1200/1583] Loss_D: 0.5320 Loss_G: 2.8334 D(x): 0.8048 D(G(z)): 0.2383 / 0.0781
[4/5][1250/1583] Loss_D: 0.8212 Loss_G: 1.3884 D(x): 0.5531 D(G(z)): 0.1236 / 0.3016
[4/5][1300/1583] Loss_D: 0.4568 Loss_G: 2.6822 D(x): 0.8278 D(G(z)): 0.2067 / 0.0912
[4/5][1350/1583] Loss_D: 0.6665 Loss_G: 1.3834 D(x): 0.6517 D(G(z)): 0.1532 / 0.2904
[4/5][1400/1583] Loss_D: 0.4927 Loss_G: 1.8337 D(x): 0.7101 D(G(z)): 0.1022 / 0.1965
[4/5][1450/1583] Loss_D: 2.2483 Loss_G: 0.2021 D(x): 0.1705 D(G(z)): 0.0452 / 0.8293
[4/5][1500/1583] Loss_D: 0.5997 Loss_G: 2.0054 D(x): 0.6909 D(G(z)): 0.1507 / 0.1733
[4/5][1550/1583] Loss_D: 1.0521 Loss_G: 4.8488 D(x): 0.9193 D(G(z)): 0.5659 / 0.0120
Results
Finally, let's check out how we did. Here we will look at three different results. First, we will see how the losses of D and G changed during training. Second, we will visualize G's output on the fixed_noise batch at the end of every epoch. And third, we will look at a batch of real data next to a batch of fake data from G.
Loss versus training iteration
Below is a plot of the generator's and discriminator's losses during training.
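The plot can be produced from the G_losses and D_losses lists collected in the training loop. A sketch of the plotting code, shown here with dummy loss values so it runs standalone (substitute the real lists after training):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs anywhere
import matplotlib.pyplot as plt

# Stand-ins for the G_losses / D_losses lists from the training loop
G_losses = [5.0, 4.2, 3.1, 2.7, 2.5]
D_losses = [1.8, 1.1, 0.9, 0.7, 0.6]

plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.savefig("loss_curves.png")
```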
Visualization of G's progression
Remember how we saved the generator's output on the fixed_noise batch after every epoch of training. Now we can visualize the training progression of G with an animation. Press the play button to start the animation. (Note: to see the animation, run the code in a Jupyter Notebook, since HTML(animator.to_jshtml()) renders the animation inline there.)
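The animation can be built from img_list with matplotlib's ArtistAnimation. A minimal sketch, using small random frames in place of the img_list grids so it is self-contained (in the real code, each frame is np.transpose(img, (1, 2, 0)) of an img_list entry):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; in a notebook this line is unnecessary
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np

fig = plt.figure(figsize=(4, 4))
plt.axis("off")

# Stand-ins for img_list entries (each an HxWx3 array)
frames = [np.random.rand(64, 64, 3) for _ in range(3)]
ims = [[plt.imshow(f, animated=True)] for f in frames]

animator = animation.ArtistAnimation(fig, ims, interval=1000,
                                     repeat_delay=1000, blit=True)
html = animator.to_jshtml()
# In a Jupyter Notebook: from IPython.display import HTML; HTML(html)
print(len(html) > 0)
```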
Real Images vs. Fake Images
Finally, let's take a look at some real images and fake images side by side (real on the left, fake on the right).
# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))
# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))
# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()
We have reached the end of our journey, but there are several places you could go from here: