Pix2Pix原理解析

1.网络搭建

class UnetGenerator(nn.Module):
    """Create a Unet-based generator"""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet generator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
                                image of size 128x128 will become of size 1x1 # at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer

        We construct the U-Net from the innermost layer to the outermost layer.
        It is a recursive process.
        """
        super(UnetGenerator, self).__init__()
        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer
        for i in range(num_downs - 5):          # add intermediate layers with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer

    def forward(self, input):
        """Standard forward"""
        return self.model(input)

Unet的模型结构如下图示,因此是从最内层开始搭建:

Pix2Pix原理解析_第1张图片

经过第一行后,网络结构如下,也就是最内层的下采样->上采样。

Pix2Pix原理解析_第2张图片

之后有一个循环,经过第一次循环后,在上一层的外围再次搭建了下采样和上采样:

Pix2Pix原理解析_第3张图片

经过第二次循环:

Pix2Pix原理解析_第4张图片

经过第三次循环:

Pix2Pix原理解析_第5张图片

可以看到每次反卷积的输入特征图的channel是1024,是因为它除了要接受上一层反卷积的输出(512维度),还要接受与其特征图大小相同的下采样层的输出(512维度),因此是1024的维度数。

循环完毕后,再次添加四次外部的降采样和反卷积,最终的网络结构如下:

UnetGenerator(
  (model): UnetSkipConnectionBlock(
    (model): Sequential(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): UnetSkipConnectionBlock(
        (model): Sequential(
          (0): LeakyReLU(negative_slope=0.2, inplace=True)
          (1): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
          (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): UnetSkipConnectionBlock(
            (model): Sequential(
              (0): LeakyReLU(negative_slope=0.2, inplace=True)
              (1): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
              (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (3): UnetSkipConnectionBlock(
                (model): Sequential(
                  (0): LeakyReLU(negative_slope=0.2, inplace=True)
                  (1): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                  (2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (3): UnetSkipConnectionBlock(
                    (model): Sequential(
                      (0): LeakyReLU(negative_slope=0.2, inplace=True)
                      (1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                      (2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                      (3): UnetSkipConnectionBlock(
                        (model): Sequential(
                          (0): LeakyReLU(negative_slope=0.2, inplace=True)
                          (1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                          (2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                          (3): UnetSkipConnectionBlock(
                            (model): Sequential(
                              (0): LeakyReLU(negative_slope=0.2, inplace=True)
                              (1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                              (2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                              (3): UnetSkipConnectionBlock(
                                (model): Sequential(
                                  (0): LeakyReLU(negative_slope=0.2, inplace=True)
                                  (1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                                  (2): ReLU(inplace=True)
                                  (3): ConvTranspose2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                                  (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                                )
                              )
                              (4): ReLU(inplace=True)
                              (5): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                              (6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                              (7): Dropout(p=0.5, inplace=False)
                            )
                          )
                          (4): ReLU(inplace=True)
                          (5): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                          (6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                          (7): Dropout(p=0.5, inplace=False)
                        )
                      )
                      (4): ReLU(inplace=True)
                      (5): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                      (6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                      (7): Dropout(p=0.5, inplace=False)
                    )
                  )
                  (4): ReLU(inplace=True)
                  (5): ConvTranspose2d(1024, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                  (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                )
              )
              (4): ReLU(inplace=True)
              (5): ConvTranspose2d(512, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
              (6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (4): ReLU(inplace=True)
          (5): ConvTranspose2d(256, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
          (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (2): ReLU(inplace=True)
      (3): ConvTranspose2d(128, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (4): Tanh()
    )
  )
)

2.反向传播过程

我们这里假定pix2pix是风格A2B,风格A就是左边的图,风格B是右边的图。

Pix2Pix原理解析_第6张图片Pix2Pix原理解析_第7张图片

反向传播的代码如下,整个是先更新D再更新G。

(1)首先向前传播,输入A,经过G,得到fakeB;

(2)开始更新D,进入backward_D函数:

  • 将A和fakeB cat起来,cat的整体相当于下图中的negative img,送入D,得到pred_fake;
  • 计算pred_fake的GAN损失,标签为0;
  • 将A与real B cat起来,cat的整体相当于positive img,送入D,得到real_fake;
  • 计算pred_real的GAN损失,标签为1;
  • fake和real的GAN相加,得到总的判别器GAN损失。

(3)开始更新G,进入backward_G函数:

  • 将A和fakeB cat起来,cat的整体相当于下图中的negative img,送入D,得到pred_fake;
  • 计算pred_fake的GAN损失,标签为1;
  • 计算real B和fake B的逐像素损失L1;
  • 将GAN损失和逐像素损失L1相加,得到总损失。

下图就可视化了上述的过程。

Pix2Pix原理解析_第8张图片

    def backward_D(self):
        """Calculate GAN loss for the discriminator"""
        # Fake; stop backprop to the generator by detaching fake_B
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)  # we use conditional GANs; we need to feed both input and output to the discriminator
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        # combine loss and calculate gradients
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Calculate GAN and L1 loss for the generator"""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
        # combine loss and calculate gradients
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize_parameters(self):
        self.forward()                   # compute fake images: G(A)
        # update D
        self.set_requires_grad(self.netD, True)  # enable backprop for D
        self.optimizer_D.zero_grad()     # set D's gradients to zero
        self.backward_D()                # calculate gradients for D
        self.optimizer_D.step()          # update D's weights
        # update G
        self.set_requires_grad(self.netD, False)  # D requires no gradients when optimizing G
        self.optimizer_G.zero_grad()        # set G's gradients to zero
        self.backward_G()                   # calculate graidents for G
        self.optimizer_G.step()             # udpate G's weights

3.PatchGAN

pix2pix还对判别器的结构做了一定的改动。之前都是对整张图像输出一个是否为真实的概率。pix2pix提出了PatchGan的概念。PatchGAN对图片中的每一个N×N的小块(patch)计算概率,然后再将这些概率求平均值作为整体的输出。

在上面的代码中pred_fake = self.netD(fake_AB.detach())的输出就不是一个概率值,而是30×30的特征图,相当于有30×30个patch。

下图表示标准的D网络结构(n_layers = 3),n_layers 为主要的特征卷积层数为3。如何理解?

  • 下面(0)(1)表示head conv层,不算在n_layers layer中;
  • (2)(3)(4)才算做是标准的一个n_layers层,因此2-4、5-7、8-10一共是3层。
  • 最后有一个卷积层,channel维度为1。

需要注意一下,patchgan channel维度最大为512。

DataParallel(
  (module): NLayerDiscriminator(
    (model): Sequential(
      (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
      (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (4): LeakyReLU(negative_slope=0.2, inplace=True)
      (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (7): LeakyReLU(negative_slope=0.2, inplace=True)
      (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
      (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): LeakyReLU(negative_slope=0.2, inplace=True)
      (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    )
  )
)

具体代码如下。与我们前面所述的稍微有些不一样,按照前面所述for n in range(1, n_layers)中相当于构建n_layers个特征提取层。但是代码中实际上构建了n_layers-1个,最后一个标准的特征提取层放在了sequence +=[...]中。

但是理解上还是可以按照前面。在spade框架中,就重新了构建patchgan的过程,其中就把最后一个标准的特征提取层也通过for n in range(1, n_layers)构建了。见https://github.com/NVlabs/SPADE/blob/master/models/networks/discriminator.py

class NLayerDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        kw = 4  #卷积核的大小
        padw = 1  #pading
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]  #head conv
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # channel = 1
        self.model = nn.Sequential(*sequence)

4.与CGAN的不同之处

下面这张图是CGAN的示意图。可以看到

  • 在CGAN模型中,生成器的输入有两个,分别为一个噪声z,以及对应的条件y(在mnist训练中将图像和标签concat在一起),输出为符合该条件的图像G(z|y)
  • 判别器的输入同样也为两个,一个是条件,另一个满足该条件的真实图像x。

pix2pix模型与CGAN最大的不同在于,不再输入噪声z。因为实验中,即便给G输入一个噪声z,G也只学会将其忽略并生成图像,噪声z对输出结果的影响几乎微乎其微。因此为了简洁性,将z去掉了。

pix2pix模型中G的输入实际上等于CGAN模型的条件y

你可能感兴趣的:(GAN,GAN)