2020 cs231n Assignment 3 Notes: Generative_Adversarial_Networks_PyTorch

Generative Adversarial Networks

Paper: Generative Adversarial Networks (Goodfellow et al., 2014)

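A generative adversarial network (GAN) consists of two parts: a generator G and a discriminator D.

Goal of the discriminator D: correctly classify each input image as either real or fake.

Goal of the generator G: produce fake images that the discriminator nevertheless classifies as real.

So, on one hand (the generator's side):

we want to maximize the probability that images produced by the generator are classified as real.

Concretely, this is the probability D(G(z)) that the discriminator labels a generated input G(z) as real, so the generator maximizes D(G(z)).

On the other hand (the discriminator's side):

we want to maximize both the probability that the discriminator classifies real data as real (D(x)) and the probability that it classifies fake data as fake (1 - D(G(z))).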
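For reference, the two goals above combine into the minimax objective of the original GAN paper, written below with the same D and G. Here D(·) denotes a probability, i.e. the sigmoid of the raw score that the discriminator network outputs. The discriminator loss used later in these notes matches the second line, up to averaging over the batch. The generator loss does not minimize log(1 - D(G(z))) directly. Instead it maximizes log D(G(z)), which is the alternative the paper itself suggests for stronger early gradients, so treat the third line as that variant.

```latex
\begin{aligned}
\min_G \max_D V(D, G) &= \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))] \\
\ell_D &= -\mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] - \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))] \\
\ell_G &= -\mathbb{E}_{z \sim p(z)}[\log D(G(z))]
\end{aligned}
```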

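Put simply, you can think of the generator G as an image producer and the discriminator D as a quality inspector:

      G makes an image and shows it to D: "Hey bro, does this image pass inspection? Does it look real?"

       D takes one look and says: "Not good enough yet, little buddy. Go back and improve it."

       G: "Got it, I'm on it."

       ……

After many rounds of this loop, the generator G learns to produce high-quality images that look real.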

Random Noise

The generator G produces images starting from random noise as its input.

import torch

def sample_noise(batch_size, dim, seed=None):
    """
    Generate a PyTorch Tensor of uniform random noise.

    Input: 
    - batch_size: Integer giving the batch size of noise to generate.
    - dim: Integer giving the dimension of noise to generate.
    
    Output:
    - A PyTorch Tensor of shape (batch_size, dim) containing uniform
      random noise in the range (-1, 1).
    """
    if seed is not None:
        torch.manual_seed(seed)
        
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

    out = torch.rand(batch_size, dim)
    out = 2 * out - 1
    return out
    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
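torch.rand samples uniformly from [0, 1), so the affine map 2u - 1 stretches that interval to [-1, 1). The sketch below reproduces the same transform in NumPy so you can check the range without PyTorch. The function name sample_noise_np and the use of numpy.random.default_rng are illustrative choices and are not part of the assignment code.

```python
import numpy as np


def sample_noise_np(batch_size, dim, seed=None):
    """Hypothetical NumPy counterpart of sample_noise: uniform noise in [-1, 1)."""
    rng = np.random.default_rng(seed)
    u = rng.random((batch_size, dim))  # uniform samples in [0, 1)
    return 2 * u - 1                   # affine map [0, 1) -> [-1, 1)


z = sample_noise_np(128, 96, seed=0)
print(z.shape)                         # (128, 96)
print(z.min() >= -1, z.max() < 1)      # True True
```

Strictly speaking, the lower endpoint -1 can be produced and the upper endpoint 1 cannot. In practice this makes no difference to training.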

Discriminator

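Build the discriminator D: fully connected layers 784 → 256 → 256 → 1 with LeakyReLU(0.01) in between. The final output is a raw score (logit), not a probability: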

import torch
import torch.nn as nn  # Flatten is a helper module provided by the assignment code

def discriminator(seed=None):
    """
    Build and return a PyTorch model implementing the architecture above.
    """

    if seed is not None:
        torch.manual_seed(seed)

    model = None

    ##############################################################################
    # TODO: Implement architecture                                               #
    #                                                                            #
    # HINT: nn.Sequential might be helpful.                                      #
    ##############################################################################
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

    model = nn.Sequential(
      Flatten(),
      nn.Linear(784, 256),
      nn.LeakyReLU(0.01),
      nn.Linear(256, 256),
      nn.LeakyReLU(0.01),
      nn.Linear(256, 1),
    )

    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    ##############################################################################
    #                               END OF YOUR CODE                             #
    ##############################################################################
    return model
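A quick sanity check is to count the trainable parameters by hand. Only the nn.Linear layers have weights. Each contributes n_in * n_out weights plus n_out biases, while Flatten and LeakyReLU contribute none. The helper name linear_params is hypothetical. Comparing this number with sum(p.numel() for p in model.parameters()) on the real model is a cheap way to catch a mistyped layer size.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of one fully connected layer (nn.Linear with bias=True)."""
    return n_in * n_out + n_out


# Layer sizes of the discriminator above: 784 -> 256 -> 256 -> 1
dims = [784, 256, 256, 1]
total = sum(linear_params(a, b) for a, b in zip(dims[:-1], dims[1:]))
print(total)  # 267009
```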

Generator

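Build the generator G: fully connected layers noise_dim → 1024 → 1024 → 784 with ReLU in between and a Tanh output, so each of the 784 pixels lies in (-1, 1):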

def generator(noise_dim=NOISE_DIM, seed=None):
    """
    Build and return a PyTorch model implementing the architecture above.
    """

    if seed is not None:
        torch.manual_seed(seed)

    model = None

    ##############################################################################
    # TODO: Implement architecture                                               #
    #                                                                            #
    # HINT: nn.Sequential might be helpful.                                      #
    ##############################################################################
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

    model = nn.Sequential(
      nn.Linear(noise_dim, 1024),
      nn.ReLU(),
      nn.Linear(1024, 1024),
      nn.ReLU(),
      nn.Linear(1024, 784),
      nn.Tanh(),
    )

    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    ##############################################################################
    #                               END OF YOUR CODE                             #
    ##############################################################################
    return model
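The same hand count works for the generator. NOISE_DIM is defined in the notebook and is not shown in these notes. The sketch assumes it is 96, so change that constant if your notebook differs. The final layer's 784 outputs are one flattened 28 x 28 MNIST image.

```python
NOISE_DIM = 96  # assumption: the notebook's noise dimension


def linear_params(n_in, n_out):
    """Weights plus biases of one nn.Linear layer (bias=True)."""
    return n_in * n_out + n_out


# Layer sizes of the generator above: noise_dim -> 1024 -> 1024 -> 784
dims = [NOISE_DIM, 1024, 1024, 784]
total = sum(linear_params(a, b) for a, b in zip(dims[:-1], dims[1:]))
print(total)  # 1952528 when NOISE_DIM = 96
```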

GAN Loss

1. Discriminator D loss

def discriminator_loss(logits_real, logits_fake):
    """
    Computes the discriminator loss described above.
    
    Inputs:
    - logits_real: PyTorch Tensor of shape (N,) giving scores for the real data.
    - logits_fake: PyTorch Tensor of shape (N,) giving scores for the fake data.
    
    Returns:
    - loss: PyTorch Tensor containing (scalar) the loss for the discriminator.
    """
    loss = None
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

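    # Targets for real data: the discriminator should output "real" (label 1)
    real_labels = torch.ones_like(logits_real).type(dtype)
    # Targets for fake data: the discriminator should output "fake" (label 0)
    fake_labels = torch.zeros_like(logits_fake).type(dtype)
    # BCE loss for classifying real data as 1
    real_loss = bce_loss(logits_real, real_labels)
    # BCE loss for classifying fake data as 0
    fake_loss = bce_loss(logits_fake, fake_labels)
    # Total discriminator loss
    loss = real_loss + fake_loss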

    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    return loss
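bce_loss and dtype come from the notebook. The sketch below assumes bce_loss is a mean binary cross-entropy on raw scores (logits), which is what nn.BCEWithLogitsLoss computes with its default reduction. It mirrors discriminator_loss in NumPy using the numerically stable form max(s, 0) - s*y + log(1 + exp(-|s|)). The names bce_with_logits and discriminator_loss_np are hypothetical.

```python
import numpy as np


def bce_with_logits(scores, targets):
    """Numerically stable mean binary cross-entropy on raw scores (logits).

    max(s, 0) - s * y + log(1 + exp(-|s|)) equals
    -[y * log(sigmoid(s)) + (1 - y) * log(1 - sigmoid(s))] without overflow.
    """
    s = np.asarray(scores, dtype=np.float64)
    y = np.asarray(targets, dtype=np.float64)
    return np.mean(np.maximum(s, 0) - s * y + np.log1p(np.exp(-np.abs(s))))


def discriminator_loss_np(logits_real, logits_fake):
    """Hypothetical NumPy mirror of discriminator_loss: real -> 1, fake -> 0."""
    real_loss = bce_with_logits(logits_real, np.ones_like(logits_real))
    fake_loss = bce_with_logits(logits_fake, np.zeros_like(logits_fake))
    return real_loss + fake_loss


# Scores of 0 mean D = sigmoid(0) = 0.5 for every input, costing log 2 per term.
print(discriminator_loss_np(np.zeros(4), np.zeros(4)))  # about 1.3863 (= 2 * log 2)
```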

2. Generator G loss

def generator_loss(logits_fake):
    """
    Computes the generator loss described above.

    Inputs:
    - logits_fake: PyTorch Tensor of shape (N,) giving scores for the fake data.
    
    Returns:
    - loss: PyTorch Tensor containing the (scalar) loss for the generator.
    """
    loss = None
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

    # The generator wants its fake data to be classified as real (label 1)
    fake_labels = torch.ones_like(logits_fake).type(dtype)
    # BCE loss of the fake scores against the "real" label 1
    loss = bce_loss(logits_fake, fake_labels)

    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    return loss
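With every target set to 1, the BCE collapses to -mean(log sigmoid(s)) = -mean(log D(G(z))). This is the "maximize log D(G(z))" generator objective. The NumPy sketch below assumes, as before, that bce_loss is a mean BCE on logits. It writes -log sigmoid(s) as log(1 + exp(-s)) via np.logaddexp. generator_loss_np and sigmoid are hypothetical helper names.

```python
import numpy as np


def sigmoid(s):
    return 1.0 / (1.0 + np.exp(-s))


def generator_loss_np(logits_fake):
    """Hypothetical NumPy mirror of generator_loss: BCE of fake scores vs. label 1.

    -log(sigmoid(s)) = log(1 + exp(-s)) = logaddexp(0, -s), which avoids overflow.
    """
    s = np.asarray(logits_fake, dtype=np.float64)
    return np.mean(np.logaddexp(0.0, -s))


scores = np.array([-2.0, 0.0, 3.0])
print(generator_loss_np(scores))
print(-np.mean(np.log(sigmoid(scores))))  # same value, computed naively
```

When D confidently rejects the fakes (very negative s), this loss grows roughly linearly in -s and its gradient stays close to -1. By contrast, log(1 - D(G(z))) flattens out in that regime, which is why this form is preferred for the generator.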

Optimizing our loss

Set up the optimizer (Adam with learning rate 1e-3, beta1 = 0.5, beta2 = 0.999):

import torch.optim as optim

def get_optimizer(model):
    """
    Construct and return an Adam optimizer for the model with learning rate 1e-3,
    beta1=0.5, and beta2=0.999.
    
    Input:
    - model: A PyTorch model that we want to optimize.
    
    Returns:
    - An Adam optimizer for the model with the desired hyperparameters.
    """
    optimizer = None
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

    optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.5, 0.999))

    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    return optimizer
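To see what betas=(0.5, 0.999) does, here is one Adam update written out in NumPy. It follows the standard Adam rule with bias correction and the same default eps=1e-8 as torch.optim.Adam. The function adam_step is a hypothetical illustration, not the PyTorch implementation. With beta1 = 0.5 instead of the usual 0.9, the momentum term averages over only a few recent steps. The DCGAN paper reports that lowering beta1 to 0.5 helped stabilize GAN training.

```python
import numpy as np


def adam_step(param, grad, m, v, t, lr=1e-3, beta1=0.5, beta2=0.999, eps=1e-8):
    """One Adam update with the hyperparameters used in get_optimizer."""
    m = beta1 * m + (1 - beta1) * grad          # first-moment (momentum) estimate
    v = beta2 * v + (1 - beta2) * grad ** 2     # second-moment estimate
    m_hat = m / (1 - beta1 ** t)                # bias corrections for step t (1-based)
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v


p, m, v = np.array([1.0]), np.zeros(1), np.zeros(1)
p, m, v = adam_step(p, np.array([2.0]), m, v, t=1)
print(p)  # about [0.999]: the first step moves roughly lr against the gradient's sign
```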

Training a GAN!

The final output of GAN training is shown below:

Iter: 3750

[Figure 1: GAN generated images at Iter 3750]

Least Squares GAN

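Least Squares GAN computes the loss differently: it replaces binary cross-entropy with a squared error on the discriminator's raw scores.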

1. Generator G loss

def ls_generator_loss(scores_fake):
    """
    Computes the Least-Squares GAN loss for the generator.
    
    Inputs:
    - scores_fake: PyTorch Tensor of shape (N,) giving scores for the fake data.
    
    Outputs:
    - loss: A PyTorch Tensor containing the loss.
    """
    loss = None
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

    loss = 0.5 * (scores_fake - 1).pow(2)
    loss = loss.mean()

    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    return loss
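In LSGAN the raw scores are regressed toward the targets 1 (real) and 0 (fake) with no sigmoid. The generator pays 0.5 * (D(G(z)) - 1)^2 per sample, averaged over the batch. A NumPy mirror of ls_generator_loss, with the hypothetical name ls_generator_loss_np:

```python
import numpy as np


def ls_generator_loss_np(scores_fake):
    """Hypothetical NumPy mirror of ls_generator_loss: 0.5 * mean((D(G(z)) - 1)^2)."""
    s = np.asarray(scores_fake, dtype=np.float64)
    return 0.5 * np.mean((s - 1.0) ** 2)


print(ls_generator_loss_np([1.0, 1.0]))  # 0.0: D already scores the fakes as real
print(ls_generator_loss_np([0.0, 0.0]))  # 0.5: D scores the fakes as fake
```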

2. Discriminator D loss

def ls_discriminator_loss(scores_real, scores_fake):
    """
    Compute the Least-Squares GAN loss for the discriminator.
    
    Inputs:
    - scores_real: PyTorch Tensor of shape (N,) giving scores for the real data.
    - scores_fake: PyTorch Tensor of shape (N,) giving scores for the fake data.
    
    Outputs:
    - loss: A PyTorch Tensor containing the loss.
    """
    loss = None
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

    loss = 0.5 * (scores_real - 1).pow(2) + 0.5 * scores_fake.pow(2)
    loss = loss.mean()
    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    return loss
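The discriminator side pushes real scores toward 1 and fake scores toward 0: 0.5 * E[(D(x) - 1)^2] + 0.5 * E[D(G(z))^2]. In the NumPy sketch below (hypothetical name ls_discriminator_loss_np), each term is averaged separately. When scores_real and scores_fake have the same length N, as they do for equal-sized real and fake batches, this equals the original's single .mean() over the summed terms.

```python
import numpy as np


def ls_discriminator_loss_np(scores_real, scores_fake):
    """Hypothetical NumPy mirror: 0.5*mean((D(x)-1)^2) + 0.5*mean(D(G(z))^2)."""
    r = np.asarray(scores_real, dtype=np.float64)
    f = np.asarray(scores_fake, dtype=np.float64)
    return 0.5 * np.mean((r - 1.0) ** 2) + 0.5 * np.mean(f ** 2)


print(ls_discriminator_loss_np([1.0, 1.0], [0.0, 0.0]))  # 0.0: perfect discriminator
print(ls_discriminator_loss_np([0.5, 0.5], [0.5, 0.5]))  # 0.25: undecided discriminator
```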

The final training output is shown below:

Iter: 3750

[Figure 2: LSGAN generated images at Iter 3750]

Deeply Convolutional GANs

This version adds convolutional layers to both networks, which gives a deeper model (DCGAN).

1. Discriminator D model

def build_dc_classifier(batch_size):
    """
    Build and return a PyTorch model for the DCGAN discriminator implementing
    the architecture above.
    """

    ##############################################################################
    # TODO: Implement architecture                                               #
    #                                                                            #
    # HINT: nn.Sequential might be helpful.                                      #
    ##############################################################################
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

    model = nn.Sequential(
      Unflatten(-1, 1, 28, 28),  # reshape (N, 784) -> (N, 1, 28, 28); no-op if already image-shaped
      nn.Conv2d(1, 32, kernel_size=5),
      nn.LeakyReLU(0.01),
      nn.MaxPool2d(kernel_size=2, stride=2),
      nn.Conv2d(32, 64, kernel_size=5),
      nn.LeakyReLU(0.01),
      nn.MaxPool2d(kernel_size=2, stride=2),
      Flatten(),
      nn.Linear(64*4*4, 64*4*4),
      nn.LeakyReLU(0.01),
      nn.Linear(64*4*4, 1),
    )
    return model
    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    ##############################################################################
    #                               END OF YOUR CODE                             #
    ##############################################################################
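The 64*4*4 input size of the first nn.Linear follows from tracking the spatial size through the conv and pooling layers. For nn.Conv2d and nn.MaxPool2d with dilation 1, the output size is floor((n + 2p - k) / s) + 1. The helper conv_out below is a hypothetical name used only for this arithmetic.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial size after nn.Conv2d / nn.MaxPool2d (dilation 1): floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


h = 28                                  # MNIST input is 1 x 28 x 28
h = conv_out(h, kernel=5)               # Conv2d(1, 32, 5)   -> 24
h = conv_out(h, kernel=2, stride=2)     # MaxPool2d(2, 2)    -> 12
h = conv_out(h, kernel=5)               # Conv2d(32, 64, 5)  -> 8
h = conv_out(h, kernel=2, stride=2)     # MaxPool2d(2, 2)    -> 4
print(h, 64 * h * h)                    # 4 1024 -> in_features of the first nn.Linear
```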

2. Generator G model

def build_dc_generator(noise_dim=NOISE_DIM):
    """
    Build and return a PyTorch model implementing the DCGAN generator using
    the architecture described above.
    """

    ##############################################################################
    # TODO: Implement architecture                                               #
    #                                                                            #
    # HINT: nn.Sequential might be helpful.                                      #
    ##############################################################################
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

    model = nn.Sequential(
      nn.Linear(noise_dim,1024),
      nn.ReLU(),
      nn.BatchNorm1d(1024),
      nn.Linear(1024,7*7*128),
      nn.ReLU(),
      nn.BatchNorm1d(7*7*128),
      Unflatten(),
      nn.ConvTranspose2d(128,64,kernel_size=4,stride=2,padding=1),
      nn.ReLU(),
      nn.BatchNorm2d(64),
      nn.ConvTranspose2d(64,1,kernel_size=4,stride=2,padding=1),
      nn.Tanh(),
      Flatten(),
    )
    return model
    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    ##############################################################################
    #                               END OF YOUR CODE                             #
    ##############################################################################
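The generator runs that process in reverse. Assuming the notebook's Unflatten() defaults to reshaping into 128 x 7 x 7 (which matches the preceding 7*7*128 linear layer), the two stride-2 transposed convolutions double the spatial size twice. For nn.ConvTranspose2d with dilation 1, the output size is (n - 1)*s - 2p + k + output_padding. conv_transpose_out is a hypothetical helper name.

```python
def conv_transpose_out(size, kernel, stride=1, padding=0, output_padding=0):
    """Spatial size after nn.ConvTranspose2d (dilation 1): (n - 1)*s - 2p + k + output_padding."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


h = 7                                                     # after Unflatten: 128 x 7 x 7
h = conv_transpose_out(h, kernel=4, stride=2, padding=1)  # -> 64 x 14 x 14
h = conv_transpose_out(h, kernel=4, stride=2, padding=1)  # -> 1 x 28 x 28
print(h, h * h)                                           # 28 784 -> matches flattened MNIST
```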

The final training output:

Iter: 1750

[Figure 3: DCGAN generated images at Iter 1750]
