DCGAN的PyTorch实现

DCGAN

1.什么是GAN

GAN是一个框架,让深度模型可以学习到数据的分布,从而通过数据的分布生成新的数据(服从同一分布)。

其由一个判别器和一个生成器构成,生成器负责生成“仿造数据”,判别器负责判断“仿造数据”的质量。两者一起进化,导致造假货和识别假货的两个模型G/D都能有超强的造假和识别假货的能力。

最终训练达到类似纳什均衡的平衡状态,就是分辨器已经分辨不出真假,其分别真假的成功率只有50%(和瞎猜没有区别)。

假设原数据分布为x(可以是一张真实图片等多维数据),判别器D(),随机变量Z,生成器为G()。D(x)生成一个标量代表x来自真实分布的概率。Z是一个随机噪声,G(Z)代表随机噪声Z(也称为隐空间向量)到真实分布P_data的映射。G(Z)的生成数据的概率分布记作P_G.

所以D(G(z))就是一个标量代表其生成图片是真实图片的概率
,同时D和G在玩一个你最小(G)我最大(D)的游戏。D想把自己分别真假图片x的成功率最大化

logD(x)

G想把造假图片z和真实图片x的差距最小化

log(1-D(G(x))。

总目标函数(loss function)可以写成:

image

2.什么是DCGAN

DCGAN是GAN的一个扩展,卷积网络做判别器,反卷积做生成器。

判别器通过大幅步的卷积网络、批量正则化、LeakyRelu激活函数构成。输入一个3*64 *64的图片,输出一个真假概率值。

生成器由一个反卷积网络、批量正则化、Relu激活函数构成,通过输入一个隐变量z(如标准正态分布)。同时输出一个3*64 *64的图片。

同时《 Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks》的原作者还给出如何设置优化器(optimizers),如何计算损失函数,如何初始化模型weights等技巧。

初始导入代码如下:

from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

3.输入设置

输入参数设置

  • dataroot - the path to the root of the dataset folder. We will talk more about the dataset in the next section
  • workers - the number of worker threads for loading the data with the DataLoader
  • batch_size - the batch size used in training. The DCGAN paper uses a batch size of 128
  • image_size - the spatial size of the images used for training. This implementation defaults to 64x64. If another size is desired, the structures of D and G must be changed.
  • nc - number of color channels in the input images. For color images this is 3
  • nz - length of latent vector
  • ngf - relates to the depth of feature maps carried through the generator
  • ndf - sets the depth of feature maps propagated through the discriminator
  • num_epochs - number of training epochs to run. Training for longer will probably lead to better results but will also take much longer
  • lr - learning rate for training. As described in the DCGAN paper, this number should be 0.0002
  • beta1 - beta1 hyperparameter for Adam optimizers. As described in paper, this number should be 0.5
  • ngpu - number of GPUs available. If this is 0, code will run in CPU mode. If this number is greater than 0 it will run on that number of GPUs
# Root directory for dataset
dataroot = "data/celeba"

# Number of workers for dataloader
workers = 2

# Batch size during training
batch_size = 128

# Spatial size of training images. All images will be resized to this
#   size using a transformer.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 5

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

4.数据

数据集用的是港中文的Celeb-A

# We can use an image folder dataset the way we have it setup.
# Create the dataset
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images
real_batch = next(iter(dataloader))
#real_batch是一个列表
#第一个元素real_batch[0]是[128,3,64,64]的tensor,就是标准的一个batch的4D结构:128张图,3个通道,64长,64宽
#第二个元素real_batch[1]是第一个元素的标签,有128个label值全为0
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

#这个函数能让图片显示
#plt.show() 

DCGAN的PyTorch实现_第1张图片

5.实现(Implementation)

5.1 参数初始化(Weight Initialization)

w初始化为均值为0,标准差为0.02的正态分布

# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

5.2 生成器(Generator)

生成器G是构造一个由向量Z(隐空间)到真实数据空间的映射(map)

  • nz=100,z输入时的长度

  • nc=3,输出时的chanel,彩色是RGB三通道

  • ngf=64,指的是生成的特征为64*64

  • 反卷积的函数为:

ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1)

参数为:1.输入、2.输出、3.核函数、4.卷积核步数、5.输入边填充、6.输出边填充、7.group、8.偏置、9.膨胀

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            #输入100,输出64*8,核函数是4*4
            
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

实例化生成器,初始化参数w

# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.2.
netG.apply(weights_init)

# Print the model
print(netG)

out:

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

5.3 判别器(Discriminator)

判别器D是一个二元分类器,判别输入的图片真假。通过输入图片进入一连串的卷积层中,经过卷积(Strided Convolution)、批量正则(BatchNorm)、LeakyReLu激活,最终通过Sigmoid激活函数输出一个概率选择。

以上的结构如有必要可以扩展更多的层,不过DCGAN的设计者通过实验发现调整步幅的卷积层比池化的下采样效果要好,因为通过卷积网络可以学习到自己的池化函数。同时批量正则化和leakly relu函数都可以提高梯度下降的质量,这些效果在同时训练G和D时显得更为突出。

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

构建D,并初始化w方程,并且输出模型的结构。

# Create the Discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.2.
netD.apply(weights_init)

# Print the model
print(netD)

out:

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)

5.4 损失函数&优化器(loss&optimizer)

用Pytorch自带的损失函数Binary Corss Entropy(BCELoss),其定义如下:

image

我们定义真图片real为1,假图片fake为0。同时设置两个优化器optimizer。在本例中
都是adam优化器,其学习率是0.0002且Beta1=0.5。为了保持生成学习的过程,我们从一个高斯分布中生成一个修正的批量数据。同时在训练过程中,我们定期放入修正的噪音给生成器G以提高拟合能力。

# Initialize BCELoss function
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1
fake_label = 0

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

5.5 训练

训练GAN是一种艺术,用不好超参数容易造成模式崩溃。我们通过D建立不同批次图片的真假差异,以及构建生成G函数以最大化logD(G(z))。

5.5.1 判别器D

训练判别器D的目的是让D能最大化识别真假图片的概率,通过随机梯度上升(ascending its stochastic gradient SGD)更新判别器。在实践中就是最大化log(D(x))+log(1-D(G(z)))。

以上步骤分为两步实现,第一步是从训练数据集中拿出一批真实图片作为样本,通过模型D,计算其loss即损失函数log(D(x)),然后再通过反向传播计算梯度更新损失函数。

第二步是通过生成器建立一批假样本,也通过D进行前向传播得到另一半loss值。即损失函数log(1-D(G(z))的值,同时也通过反向传播更新loss,通过1个batches的迭代更新,我们称为一次D的优化(optimizer)

5.5.2 生成器G

DCGAN的PyTorch实现_第2张图片

在GAN原始版本中G的实现是通过最小化log(1-D(G(z)))以增加更好的造假能力。值得注意的是原始版本并没有提供足够的梯度更新策略,特别在早期的训练学习过程中。作为修正,我们用最大化log(D(G(z)))来替代原先的策略。其中关键名词如下:

  • Loss_D

计算所以批次的真假图片的判别函数,即loss= log(D(x))+log(D(G(Z))

  • Loss_G

生成图片的损失函数即log(D(G(z)))

  • D(x)

输出真样本批次的为真概率,从一开始的1到理论上的拟合至0.5(即G训练好的时候)

  • D(G(z))

判别输出生成图片为真的概率,从一开始的0到理论上拟合至0.5(同为G训练好的时候)

训练时间和训练整体样本的次数(epoch),和样本的大小有关,代码如下:

# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

out:

Starting Training Loop...
[0/5][0/1583]   Loss_D: 2.0937  Loss_G: 5.2060  D(x): 0.5704    D(G(z)): 0.6680 / 0.0090
[0/5][50/1583]  Loss_D: 0.1916  Loss_G: 9.5846  D(x): 0.9472    D(G(z)): 0.0364 / 0.0002
[0/5][100/1583] Loss_D: 4.0207  Loss_G: 21.2494 D(x): 0.2445    D(G(z)): 0.0000 / 0.0000
[0/5][150/1583] Loss_D: 0.5569  Loss_G: 3.1977  D(x): 0.7294    D(G(z)): 0.0974 / 0.0609
[0/5][200/1583] Loss_D: 0.2320  Loss_G: 3.3187  D(x): 0.9009    D(G(z)): 0.0805 / 0.0659
[0/5][250/1583] Loss_D: 0.7203  Loss_G: 5.9229  D(x): 0.8500    D(G(z)): 0.3485 / 0.0062
[0/5][300/1583] Loss_D: 0.6775  Loss_G: 4.0545  D(x): 0.8330    D(G(z)): 0.3379 / 0.0353
[0/5][350/1583] Loss_D: 0.7549  Loss_G: 5.9064  D(x): 0.9227    D(G(z)): 0.4109 / 0.0084
[0/5][400/1583] Loss_D: 1.0655  Loss_G: 2.5097  D(x): 0.4933    D(G(z)): 0.0269 / 0.1286
[0/5][450/1583] Loss_D: 0.6321  Loss_G: 2.7811  D(x): 0.6453    D(G(z)): 0.0610 / 0.1026
[0/5][500/1583] Loss_D: 0.5064  Loss_G: 4.1399  D(x): 0.9475    D(G(z)): 0.3009 / 0.0350
[0/5][550/1583] Loss_D: 0.3838  Loss_G: 4.0321  D(x): 0.8221    D(G(z)): 0.1218 / 0.0331
[0/5][600/1583] Loss_D: 0.5549  Loss_G: 4.6055  D(x): 0.8230    D(G(z)): 0.2049 / 0.0171
[0/5][650/1583] Loss_D: 0.2821  Loss_G: 6.8137  D(x): 0.8276    D(G(z)): 0.0164 / 0.0027
[0/5][700/1583] Loss_D: 0.6422  Loss_G: 5.0119  D(x): 0.8267    D(G(z)): 0.2827 / 0.0146
[0/5][750/1583] Loss_D: 0.4332  Loss_G: 4.3659  D(x): 0.9239    D(G(z)): 0.2307 / 0.0291
[0/5][800/1583] Loss_D: 0.5344  Loss_G: 3.4145  D(x): 0.7208    D(G(z)): 0.0891 / 0.0744
[0/5][850/1583] Loss_D: 0.8094  Loss_G: 2.9318  D(x): 0.5903    D(G(z)): 0.0602 / 0.0979
[0/5][900/1583] Loss_D: 0.1598  Loss_G: 6.4141  D(x): 0.9228    D(G(z)): 0.0630 / 0.0046
[0/5][950/1583] Loss_D: 0.5083  Loss_G: 5.5467  D(x): 0.9226    D(G(z)): 0.2916 / 0.0112
[0/5][1000/1583]        Loss_D: 0.6738  Loss_G: 3.9958  D(x): 0.7622    D(G(z)): 0.2480 / 0.0410
[0/5][1050/1583]        Loss_D: 0.2155  Loss_G: 3.8838  D(x): 0.9092    D(G(z)): 0.0819 / 0.0432
[0/5][1100/1583]        Loss_D: 1.1708  Loss_G: 1.9610  D(x): 0.4709    D(G(z)): 0.0064 / 0.2448
[0/5][1150/1583]        Loss_D: 0.7506  Loss_G: 6.9292  D(x): 0.8797    D(G(z)): 0.3728 / 0.0019
[0/5][1200/1583]        Loss_D: 0.2133  Loss_G: 5.5082  D(x): 0.9436    D(G(z)): 0.1272 / 0.0102
[0/5][1250/1583]        Loss_D: 0.5156  Loss_G: 3.8660  D(x): 0.8073    D(G(z)): 0.1993 / 0.0357
[0/5][1300/1583]        Loss_D: 0.4848  Loss_G: 5.0770  D(x): 0.9170    D(G(z)): 0.2847 / 0.0109
[0/5][1350/1583]        Loss_D: 0.6596  Loss_G: 4.7626  D(x): 0.8414    D(G(z)): 0.3232 / 0.0145
[0/5][1400/1583]        Loss_D: 0.2799  Loss_G: 5.1604  D(x): 0.9154    D(G(z)): 0.1494 / 0.0156
[0/5][1450/1583]        Loss_D: 0.4756  Loss_G: 2.9344  D(x): 0.8164    D(G(z)): 0.1785 / 0.0955
[0/5][1500/1583]        Loss_D: 0.3904  Loss_G: 2.3755  D(x): 0.7652    D(G(z)): 0.0587 / 0.1328
[0/5][1550/1583]        Loss_D: 1.2817  Loss_G: 1.2689  D(x): 0.3769    D(G(z)): 0.0221 / 0.3693
[1/5][0/1583]   Loss_D: 0.5365  Loss_G: 3.0092  D(x): 0.7437    D(G(z)): 0.1574 / 0.0836
[1/5][50/1583]  Loss_D: 0.4959  Loss_G: 5.4086  D(x): 0.9422    D(G(z)): 0.2960 / 0.0086
[1/5][100/1583] Loss_D: 0.2685  Loss_G: 3.6553  D(x): 0.8455    D(G(z)): 0.0640 / 0.0457
[1/5][150/1583] Loss_D: 0.6243  Loss_G: 4.6128  D(x): 0.8467    D(G(z)): 0.2878 / 0.0203
[1/5][200/1583] Loss_D: 0.4369  Loss_G: 2.8268  D(x): 0.7591    D(G(z)): 0.0871 / 0.0871
[1/5][250/1583] Loss_D: 0.4244  Loss_G: 3.7669  D(x): 0.8641    D(G(z)): 0.1952 / 0.0369
[1/5][300/1583] Loss_D: 0.7487  Loss_G: 2.5417  D(x): 0.6388    D(G(z)): 0.0948 / 0.1263
[1/5][350/1583] Loss_D: 0.5359  Loss_G: 2.9435  D(x): 0.6996    D(G(z)): 0.0836 / 0.0864
[1/5][400/1583] Loss_D: 0.3469  Loss_G: 2.7581  D(x): 0.8046    D(G(z)): 0.0755 / 0.1036
[1/5][450/1583] Loss_D: 0.5065  Loss_G: 2.8547  D(x): 0.7491    D(G(z)): 0.1494 / 0.0879
[1/5][500/1583] Loss_D: 0.3959  Loss_G: 3.3236  D(x): 0.8292    D(G(z)): 0.1328 / 0.0554
[1/5][550/1583] Loss_D: 0.6679  Loss_G: 5.8782  D(x): 0.9178    D(G(z)): 0.3802 / 0.0075
[1/5][600/1583] Loss_D: 0.8844  Loss_G: 1.9449  D(x): 0.5367    D(G(z)): 0.0326 / 0.1984
[1/5][650/1583] Loss_D: 0.8474  Loss_G: 2.0978  D(x): 0.6395    D(G(z)): 0.1883 / 0.1803
[1/5][700/1583] Loss_D: 0.4682  Loss_G: 5.1056  D(x): 0.8963    D(G(z)): 0.2520 / 0.0137
[1/5][750/1583] Loss_D: 0.4315  Loss_G: 4.0099  D(x): 0.8957    D(G(z)): 0.2441 / 0.0304
[1/5][800/1583] Loss_D: 0.4492  Loss_G: 4.1587  D(x): 0.9090    D(G(z)): 0.2656 / 0.0231
[1/5][850/1583] Loss_D: 0.7694  Loss_G: 1.2065  D(x): 0.5726    D(G(z)): 0.0254 / 0.3785
[1/5][900/1583] Loss_D: 0.3543  Loss_G: 4.0476  D(x): 0.8919    D(G(z)): 0.1873 / 0.0284
[1/5][950/1583] Loss_D: 0.5111  Loss_G: 2.3574  D(x): 0.7082    D(G(z)): 0.0835 / 0.1288
[1/5][1000/1583]        Loss_D: 0.5802  Loss_G: 5.4608  D(x): 0.9395    D(G(z)): 0.3649 / 0.0077
[1/5][1050/1583]        Loss_D: 1.0051  Loss_G: 2.4068  D(x): 0.5352    D(G(z)): 0.0322 / 0.1486
[1/5][1100/1583]        Loss_D: 0.3509  Loss_G: 3.6524  D(x): 0.9101    D(G(z)): 0.2070 / 0.0387
[1/5][1150/1583]        Loss_D: 0.9412  Loss_G: 5.4059  D(x): 0.9597    D(G(z)): 0.5325 / 0.0080
[1/5][1200/1583]        Loss_D: 0.5332  Loss_G: 3.1298  D(x): 0.7943    D(G(z)): 0.2138 / 0.0630
[1/5][1250/1583]        Loss_D: 0.6025  Loss_G: 3.5758  D(x): 0.8679    D(G(z)): 0.3182 / 0.0428
[1/5][1300/1583]        Loss_D: 0.7154  Loss_G: 2.1555  D(x): 0.5657    D(G(z)): 0.0379 / 0.1685
[1/5][1350/1583]        Loss_D: 0.4168  Loss_G: 2.1878  D(x): 0.7452    D(G(z)): 0.0645 / 0.1534
[1/5][1400/1583]        Loss_D: 0.8991  Loss_G: 5.3523  D(x): 0.9256    D(G(z)): 0.4967 / 0.0074
[1/5][1450/1583]        Loss_D: 0.4778  Loss_G: 3.8499  D(x): 0.8844    D(G(z)): 0.2655 / 0.0350
[1/5][1500/1583]        Loss_D: 0.5049  Loss_G: 2.5450  D(x): 0.7880    D(G(z)): 0.1906 / 0.1010
[1/5][1550/1583]        Loss_D: 1.0468  Loss_G: 1.9007  D(x): 0.4378    D(G(z)): 0.0346 / 0.2260
[2/5][0/1583]   Loss_D: 0.5008  Loss_G: 3.5294  D(x): 0.9006    D(G(z)): 0.2844 / 0.0466
[2/5][50/1583]  Loss_D: 0.5024  Loss_G: 2.3252  D(x): 0.7413    D(G(z)): 0.1450 / 0.1267
[2/5][100/1583] Loss_D: 0.7520  Loss_G: 2.0230  D(x): 0.5753    D(G(z)): 0.0835 / 0.1797
[2/5][150/1583] Loss_D: 0.3734  Loss_G: 2.7221  D(x): 0.8502    D(G(z)): 0.1689 / 0.0889
[2/5][200/1583] Loss_D: 0.5891  Loss_G: 2.6314  D(x): 0.7453    D(G(z)): 0.2076 / 0.1032
[2/5][250/1583] Loss_D: 1.1471  Loss_G: 3.5814  D(x): 0.8959    D(G(z)): 0.5563 / 0.0545
[2/5][300/1583] Loss_D: 0.5756  Loss_G: 3.1905  D(x): 0.8738    D(G(z)): 0.3128 / 0.0605
[2/5][350/1583] Loss_D: 0.5971  Loss_G: 2.9928  D(x): 0.8177    D(G(z)): 0.2657 / 0.0739
[2/5][400/1583] Loss_D: 0.6856  Loss_G: 3.8514  D(x): 0.8880    D(G(z)): 0.3835 / 0.0298
[2/5][450/1583] Loss_D: 0.6088  Loss_G: 1.7919  D(x): 0.6660    D(G(z)): 0.1227 / 0.2189
[2/5][500/1583] Loss_D: 0.7147  Loss_G: 2.6453  D(x): 0.8321    D(G(z)): 0.3531 / 0.1007
[2/5][550/1583] Loss_D: 0.5759  Loss_G: 2.9074  D(x): 0.8269    D(G(z)): 0.2833 / 0.0738
[2/5][600/1583] Loss_D: 0.5678  Loss_G: 2.6149  D(x): 0.7928    D(G(z)): 0.2516 / 0.0956
[2/5][650/1583] Loss_D: 0.9501  Loss_G: 1.1814  D(x): 0.5916    D(G(z)): 0.2322 / 0.3815
[2/5][700/1583] Loss_D: 0.4551  Loss_G: 2.5074  D(x): 0.8331    D(G(z)): 0.2047 / 0.1129
[2/5][750/1583] Loss_D: 0.4560  Loss_G: 2.3947  D(x): 0.7525    D(G(z)): 0.1240 / 0.1147
[2/5][800/1583] Loss_D: 1.1853  Loss_G: 5.1657  D(x): 0.9202    D(G(z)): 0.6049 / 0.0091
[2/5][850/1583] Loss_D: 0.5514  Loss_G: 3.0085  D(x): 0.8497    D(G(z)): 0.2890 / 0.0685
[2/5][900/1583] Loss_D: 0.6882  Loss_G: 1.8971  D(x): 0.6970    D(G(z)): 0.2332 / 0.1909
[2/5][950/1583] Loss_D: 1.1220  Loss_G: 0.7904  D(x): 0.4095    D(G(z)): 0.0570 / 0.4975
[2/5][1000/1583]        Loss_D: 1.3335  Loss_G: 0.3115  D(x): 0.3347    D(G(z)): 0.0262 / 0.7661
[2/5][1050/1583]        Loss_D: 1.7281  Loss_G: 0.8212  D(x): 0.2437    D(G(z)): 0.0261 / 0.5179
[2/5][1100/1583]        Loss_D: 0.9401  Loss_G: 3.7894  D(x): 0.9033    D(G(z)): 0.5104 / 0.0349
[2/5][1150/1583]        Loss_D: 0.8078  Loss_G: 3.9862  D(x): 0.9178    D(G(z)): 0.4608 / 0.0286
[2/5][1200/1583]        Loss_D: 0.5182  Loss_G: 3.1859  D(x): 0.8568    D(G(z)): 0.2787 / 0.0554
[2/5][1250/1583]        Loss_D: 0.5092  Loss_G: 2.3530  D(x): 0.8015    D(G(z)): 0.2122 / 0.1188
[2/5][1300/1583]        Loss_D: 1.2668  Loss_G: 0.5543  D(x): 0.3424    D(G(z)): 0.0165 / 0.6271
[2/5][1350/1583]        Loss_D: 0.7197  Loss_G: 3.8595  D(x): 0.9043    D(G(z)): 0.4208 / 0.0299
[2/5][1400/1583]        Loss_D: 0.5428  Loss_G: 2.6526  D(x): 0.8873    D(G(z)): 0.3056 / 0.0961
[2/5][1450/1583]        Loss_D: 0.6610  Loss_G: 4.2385  D(x): 0.9272    D(G(z)): 0.3985 / 0.0211
[2/5][1500/1583]        Loss_D: 0.8172  Loss_G: 3.2164  D(x): 0.8811    D(G(z)): 0.4422 / 0.0612
[2/5][1550/1583]        Loss_D: 0.6449  Loss_G: 3.8452  D(x): 0.9130    D(G(z)): 0.3813 / 0.0325
[3/5][0/1583]   Loss_D: 0.7677  Loss_G: 1.7745  D(x): 0.5928    D(G(z)): 0.1388 / 0.2182
[3/5][50/1583]  Loss_D: 0.7981  Loss_G: 2.9624  D(x): 0.8315    D(G(z)): 0.4131 / 0.0735
[3/5][100/1583] Loss_D: 0.5679  Loss_G: 1.8958  D(x): 0.7173    D(G(z)): 0.1667 / 0.1914
[3/5][150/1583] Loss_D: 0.8576  Loss_G: 1.5904  D(x): 0.5391    D(G(z)): 0.1158 / 0.2699
[3/5][200/1583] Loss_D: 0.8644  Loss_G: 1.6487  D(x): 0.5868    D(G(z)): 0.1933 / 0.2319
[3/5][250/1583] Loss_D: 0.5331  Loss_G: 3.0401  D(x): 0.8831    D(G(z)): 0.3022 / 0.0608
[3/5][300/1583] Loss_D: 1.2449  Loss_G: 2.9489  D(x): 0.8759    D(G(z)): 0.5865 / 0.0828
[3/5][350/1583] Loss_D: 1.7188  Loss_G: 0.5466  D(x): 0.2664    D(G(z)): 0.0539 / 0.6320
[3/5][400/1583] Loss_D: 0.5794  Loss_G: 2.7556  D(x): 0.7984    D(G(z)): 0.2640 / 0.0787
[3/5][450/1583] Loss_D: 0.6916  Loss_G: 3.1434  D(x): 0.8813    D(G(z)): 0.3955 / 0.0578
[3/5][500/1583] Loss_D: 0.8415  Loss_G: 1.9770  D(x): 0.6981    D(G(z)): 0.3120 / 0.1639
[3/5][550/1583] Loss_D: 0.6394  Loss_G: 2.4790  D(x): 0.8093    D(G(z)): 0.2990 / 0.1082
[3/5][600/1583] Loss_D: 0.7545  Loss_G: 1.6259  D(x): 0.6042    D(G(z)): 0.1454 / 0.2401
[3/5][650/1583] Loss_D: 0.5494  Loss_G: 2.1957  D(x): 0.8292    D(G(z)): 0.2727 / 0.1414
[3/5][700/1583] Loss_D: 1.5095  Loss_G: 5.1368  D(x): 0.9269    D(G(z)): 0.6897 / 0.0095
[3/5][750/1583] Loss_D: 0.4714  Loss_G: 2.1401  D(x): 0.8137    D(G(z)): 0.2101 / 0.1501
[3/5][800/1583] Loss_D: 0.7118  Loss_G: 3.2356  D(x): 0.8190    D(G(z)): 0.3579 / 0.0540
[3/5][850/1583] Loss_D: 0.6392  Loss_G: 1.6740  D(x): 0.6650    D(G(z)): 0.1402 / 0.2391
[3/5][900/1583] Loss_D: 0.5303  Loss_G: 2.8854  D(x): 0.7900    D(G(z)): 0.2204 / 0.0740
[3/5][950/1583] Loss_D: 0.6333  Loss_G: 2.1030  D(x): 0.6946    D(G(z)): 0.1882 / 0.1572
[3/5][1000/1583]        Loss_D: 0.8715  Loss_G: 1.6630  D(x): 0.5222    D(G(z)): 0.0890 / 0.2590
[3/5][1050/1583]        Loss_D: 0.6139  Loss_G: 3.1772  D(x): 0.8609    D(G(z)): 0.3400 / 0.0558
[3/5][1100/1583]        Loss_D: 0.6673  Loss_G: 3.4143  D(x): 0.9044    D(G(z)): 0.3910 / 0.0435
[3/5][1150/1583]        Loss_D: 0.6554  Loss_G: 3.4282  D(x): 0.8429    D(G(z)): 0.3347 / 0.0484
[3/5][1200/1583]        Loss_D: 0.6184  Loss_G: 1.7371  D(x): 0.6531    D(G(z)): 0.1177 / 0.2132
[3/5][1250/1583]        Loss_D: 0.8293  Loss_G: 3.1246  D(x): 0.7821    D(G(z)): 0.3883 / 0.0594
[3/5][1300/1583]        Loss_D: 0.5211  Loss_G: 2.0112  D(x): 0.7308    D(G(z)): 0.1503 / 0.1637
[3/5][1350/1583]        Loss_D: 0.7389  Loss_G: 1.4238  D(x): 0.5854    D(G(z)): 0.1181 / 0.2935
[3/5][1400/1583]        Loss_D: 0.6608  Loss_G: 3.1928  D(x): 0.7803    D(G(z)): 0.2922 / 0.0580
[3/5][1450/1583]        Loss_D: 0.6381  Loss_G: 3.4123  D(x): 0.8340    D(G(z)): 0.3337 / 0.0450
[3/5][1500/1583]        Loss_D: 0.7027  Loss_G: 3.1943  D(x): 0.9058    D(G(z)): 0.4113 / 0.0556
[3/5][1550/1583]        Loss_D: 0.6849  Loss_G: 2.9714  D(x): 0.8258    D(G(z)): 0.3499 / 0.0704
[4/5][0/1583]   Loss_D: 0.7685  Loss_G: 1.7204  D(x): 0.5788    D(G(z)): 0.1084 / 0.2252
[4/5][50/1583]  Loss_D: 0.6194  Loss_G: 1.4702  D(x): 0.6214    D(G(z)): 0.0700 / 0.2812
[4/5][100/1583] Loss_D: 0.5243  Loss_G: 2.4332  D(x): 0.8206    D(G(z)): 0.2515 / 0.1099
[4/5][150/1583] Loss_D: 0.8506  Loss_G: 1.0129  D(x): 0.5094    D(G(z)): 0.0647 / 0.4126
[4/5][200/1583] Loss_D: 1.1715  Loss_G: 2.5120  D(x): 0.5642    D(G(z)): 0.3481 / 0.1214
[4/5][250/1583] Loss_D: 0.4317  Loss_G: 2.7731  D(x): 0.8405    D(G(z)): 0.2088 / 0.0791
[4/5][300/1583] Loss_D: 1.2310  Loss_G: 0.4177  D(x): 0.3812    D(G(z)): 0.0576 / 0.6799
[4/5][350/1583] Loss_D: 0.5565  Loss_G: 2.7405  D(x): 0.8525    D(G(z)): 0.3005 / 0.0810
[4/5][400/1583] Loss_D: 0.4918  Loss_G: 3.5705  D(x): 0.8863    D(G(z)): 0.2833 / 0.0371
[4/5][450/1583] Loss_D: 0.6403  Loss_G: 2.7691  D(x): 0.8543    D(G(z)): 0.3406 / 0.0812
[4/5][500/1583] Loss_D: 0.5944  Loss_G: 1.4696  D(x): 0.6849    D(G(z)): 0.1325 / 0.2682
[4/5][550/1583] Loss_D: 0.8678  Loss_G: 4.1990  D(x): 0.9529    D(G(z)): 0.5105 / 0.0202
[4/5][600/1583] Loss_D: 0.8326  Loss_G: 1.1841  D(x): 0.5175    D(G(z)): 0.0679 / 0.3628
[4/5][650/1583] Loss_D: 0.5198  Loss_G: 2.4393  D(x): 0.7668    D(G(z)): 0.1943 / 0.1148
[4/5][700/1583] Loss_D: 0.8029  Loss_G: 4.0836  D(x): 0.8791    D(G(z)): 0.4448 / 0.0229
[4/5][750/1583] Loss_D: 0.8636  Loss_G: 2.0386  D(x): 0.5234    D(G(z)): 0.0899 / 0.1846
[4/5][800/1583] Loss_D: 0.5041  Loss_G: 3.0354  D(x): 0.8302    D(G(z)): 0.2301 / 0.0609
[4/5][850/1583] Loss_D: 0.7514  Loss_G: 1.2513  D(x): 0.5578    D(G(z)): 0.0899 / 0.3480
[4/5][900/1583] Loss_D: 0.6650  Loss_G: 1.2806  D(x): 0.6675    D(G(z)): 0.1925 / 0.3201
[4/5][950/1583] Loss_D: 0.5754  Loss_G: 3.0898  D(x): 0.8730    D(G(z)): 0.3233 / 0.0597
[4/5][1000/1583]        Loss_D: 0.9327  Loss_G: 0.7588  D(x): 0.4674    D(G(z)): 0.0434 / 0.5174
[4/5][1050/1583]        Loss_D: 0.9255  Loss_G: 0.9513  D(x): 0.5029    D(G(z)): 0.1161 / 0.4196
[4/5][1100/1583]        Loss_D: 0.6573  Loss_G: 3.4663  D(x): 0.8755    D(G(z)): 0.3674 / 0.0403
[4/5][1150/1583]        Loss_D: 0.9803  Loss_G: 1.2451  D(x): 0.4602    D(G(z)): 0.0978 / 0.3432
[4/5][1200/1583]        Loss_D: 0.5560  Loss_G: 2.5421  D(x): 0.7617    D(G(z)): 0.2097 / 0.1020
[4/5][1250/1583]        Loss_D: 0.7573  Loss_G: 1.9034  D(x): 0.6477    D(G(z)): 0.2158 / 0.1890
[4/5][1300/1583]        Loss_D: 0.4733  Loss_G: 2.7071  D(x): 0.8271    D(G(z)): 0.2169 / 0.0882
[4/5][1350/1583]        Loss_D: 1.0812  Loss_G: 1.1500  D(x): 0.5225    D(G(z)): 0.2278 / 0.3626
[4/5][1400/1583]        Loss_D: 1.5454  Loss_G: 5.2881  D(x): 0.9620    D(G(z)): 0.7085 / 0.0089
[4/5][1450/1583]        Loss_D: 0.3576  Loss_G: 3.1023  D(x): 0.8687    D(G(z)): 0.1726 / 0.0584
[4/5][1500/1583]        Loss_D: 0.5330  Loss_G: 1.9979  D(x): 0.7277    D(G(z)): 0.1597 / 0.1680
[4/5][1550/1583]        Loss_D: 0.8927  Loss_G: 4.1379  D(x): 0.9345    D(G(z)): 0.5081 / 0.0224

5.6 结果

从三个不同方面看实验结果:

  • 看G和D两个损失函数的变化
  • 看每轮epoch训练G生成图片的结果
  • 对比一批生成图片和一批真实图片(64张)

a.loss变化

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

DCGAN的PyTorch实现_第3张图片
b.图片生成变化

#%%capture
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())

c.对比真假图片

# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()

DCGAN的PyTorch实现_第4张图片

6.下一步

  • Train for longer to see how good the results get

多训练几次,如增加epoch看效果

  • Modify this model to take a different dataset and possibly change the size of the images and the model architecture

换其他数据集、或者调整一些模型结构

  • Check out some other cool GAN projects here

试试其他有趣的GAN应用–https://github.com/nashory/gans-awesome-applications

  • Create GANs that generate music

用GAN生成音乐

7.参考:

https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html#

https://github.com/soumith/ganhacks#authors

你可能感兴趣的:(DCGAN,GAN,PyTorch)