Pose-Transfer代码阅读笔记

一、简介

笔者阅读的Pose-Transfer代码为https://github.com/tengteng95/Pose-Transfer的Pytorch_v1.0分支,适应于pytorch1.x的版本。以下讲的流程为ReadMe中给出的运行参数的情况下的流程,它是论文Progressive Pose Attention for Person Image Generation in CVPR19 (Oral)的代码。

二、网络结构

神经网络相关的代码阅读入口为models/PATN.py中的class TransferModel

2.1 生成网络netG

定义了生成网络netG的的代码为:


 netG = PATNetwork(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
                                           n_blocks=9, gpu_ids=gpu_ids, n_downsampling=n_downsampling)

参数:input_nc(输入channel数),output_nc(输出channel数),ngf(channel数相关,可理解为特征图个数,生成网络基于此数的倍数进行channel变化), n_blocks(PATB个数),n_downsampling(下采样卷积个数)
网络结构相关代码如下:
先看PATNetwork的forward前向计算代码:

 def forward(self, input): # x from stream 1 and stream 2
        # here x should be a tuple
        x1, x2 = input
        # down_sample
        x1 = self.stream1_down(x1)
        x2 = self.stream2_down(x2)
        # att_block
        for model in self.att:
            x1, x2, _ = model(x1, x2)

        # up_sample
        x1 = self.stream1_up(x1)

        return x1

也就是生成网络大致可分为下采样部分、att_block部分、上采样部分。att_block的上面分支的最后输出经过上采样为最终结果。

2.1.1 下采样部分

1.先是Padding层:

model_stream1_down = [nn.ReflectionPad2d(3),
                    nn.Conv2d(self.input_nc_s1, ngf, kernel_size=7, padding=0,
                           bias=use_bias),
                    norm_layer(ngf),
                    nn.ReLU(True)]

 model_stream2_down = [nn.ReflectionPad2d(3),
                    nn.Conv2d(self.input_nc_s2, ngf, kernel_size=7, padding=0,
                           bias=use_bias),
                    norm_layer(ngf),
                    nn.ReLU(True)]

2.n_downsampling个下采样卷积层:

        for i in range(n_downsampling):
            mult = 2**i
            model_stream1_down += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=use_bias),
                            norm_layer(ngf * mult * 2),
                            nn.ReLU(True)]
            model_stream2_down += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=use_bias),
                            norm_layer(ngf * mult * 2),
                            nn.ReLU(True)]

3.链接起层,赋值

        self.stream1_down = nn.Sequential(*model_stream1_down)
        self.stream2_down = nn.Sequential(*model_stream2_down)

2.1.2 att block部分,即对应论文中的PATB,即Pose-Attentional Transfer Network。

贴一张论文里的图:


pose-transfer网络结构.png

后文中讲的PATB的第一分支就是上面的分支,第二分支就是下面的分支。

mult = 2**n_downsampling
        cated_stream2 = [True for i in range(n_blocks)]
        cated_stream2[0] = False
        attBlock = nn.ModuleList()
        for i in range(n_blocks):
            attBlock.append(PATBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,            use_dropout=use_dropout, use_bias=use_bias, cated_stream2=cated_stream2[i]))

也就是n_blocks个PATB块,一个PATB块构成为:
首先是变量定义,conv_blocks用于存储各个层:

 conv_block = []
 p = 0

前向计算forward函数代码如下:

    def forward(self, x1, x2):
        x1_out = self.conv_block_stream1(x1)
        x2_out = self.conv_block_stream2(x2)
        # att = F.sigmoid(x2_out)
        att = torch.sigmoid(x2_out)

        x1_out = x1_out * att
        out = x1 + x1_out # residual connection

        # stream2 receive feedback from stream1
        x2_out = torch.cat((x2_out, out), 1)
        return out, x2_out, x1_out

可以看出是两个输入,两个输出,结合论文图示看更易理解。在整个生成网络的forward函数中,取out、x2_out进行下一步运算。
其中,conv_block_stream1与conv_block_stream2的构建代码为:

        self.conv_block_stream1 = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, cal_att=False)
        self.conv_block_stream2 = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, cal_att=True, cated_stream2=cated_stream2)

conv_block_stream1与conv_block_stream2结构具体的网络结构如下:
1.Padding

        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1

2.Normalize

        if cated_stream2:
            conv_block += [nn.Conv2d(dim*2, dim*2, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim*2),
                       nn.ReLU(True)]
        else:
            conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                           norm_layer(dim),
                           nn.ReLU(True)]

其中cated_stream2,在第一个分支为False,在第二个分支第一个PATB中为False,在第二个及以后中为True,因为PATB的第二个分支的最后输出为为第一个分支卷积结果和第二个分支卷积结果的拼接(x2_out = torch.cat((x2_out, out), 1))
3.dropout层(可选)

        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

4.再次Padding
代码与第一次Padding相同
5.卷积层

        if cal_att:
            if cated_stream2:
                conv_block += [nn.Conv2d(dim*2, dim, kernel_size=3, padding=p, bias=use_bias)]
            else:
                conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)]
        else:
            conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim)]

cal_att为False时是第一分支,cal_att为第二分支。第一分支输入输出channel均为dim,第二个分支则需要把dim*2的输入channel转成dim的输出channel,方便与第一分支进行拼接。

2.1.3 上采样部分

        model_stream1_up = []
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model_stream1_up += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                            norm_layer(int(ngf * mult / 2)),
                            nn.ReLU(True)]

大致就是n_downsampling个反卷积层的上采样。

2.2 分类网络

有两个分类网络,分别为netD_PB和netD_PP。netD_PB用于评判输出图片Pg和目标姿态St的的匹配程度(英文原文:how well Pg align with the target pose St(shape consistency).),netD_PP用于评判输出图片Pg是否包含输入图片Pc中的同一个人(英文原文:judge how likely Pg contains the same person in Pc (appearance consistency))
netD_PB和net_PP结构相同:

            use_sigmoid = opt.no_lsgan
            if opt.with_D_PB:
                self.netD_PB = networks.define_D(opt.P_input_nc+opt.BP_input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids,
                                            not opt.no_dropout_D,
                                            n_downsampling = opt.D_n_downsampling)

            if opt.with_D_PP:
                self.netD_PP = networks.define_D(opt.P_input_nc+opt.P_input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids,
                                            not opt.no_dropout_D,
                                            n_downsampling = opt.D_n_downsampling)

参数:input_nc(输入channel数),output_nv(输出channel数),ndf(channel数相关,可理解为特征图个数,分类网络基于此数的倍数进行channel变化),which_model_netD(分类器的基础网络,如resnet),n_layers_D(分类器中的block个数),norm(instance normalization or batch normalization),n_downsampling(下采样卷积个数)
代码中define_D提供的是ResnetDiscriminator。首先看ResnetDiscriminator的forward函数:

    def forward(self, input):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)

就是只有一个self.model跑输入得到输出即可。
self.model的结构为:
1.Padding

model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
                           bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

2.下采样部分

 if n_downsampling <= 2:
            for i in range(n_downsampling):
                mult = 2 ** i
                model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                    stride=2, padding=1, bias=use_bias),
                          norm_layer(ngf * mult * 2),
                          nn.ReLU(True)]
        elif n_downsampling == 3:
            mult = 2 ** 0
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]
            mult = 2 ** 1
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]
            mult = 2 ** 2
            model += [nn.Conv2d(ngf * mult, ngf * mult, kernel_size=3,
                                stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult),
                      nn.ReLU(True)]

        if n_downsampling <= 2:
            mult = 2 ** n_downsampling
        else:
            mult = 4

就是凑出n_downsampling个下采样卷积层
3.残差块部分

        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout,
                                  use_bias=use_bias)]

ResnetBlock就是resnet中的Identity Block,不再展开叙述了
4.sigmoid层

        if use_sigmoid:
            model += [nn.Sigmoid()]

三、损失函数计算

在train.py中,调用的model.optimize_parameters()调整网络权重函数具体代码如下:

 # forward
        self.forward()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

        # D_P
        if self.opt.with_D_PP:
            for i in range(self.opt.DG_ratio):
                self.optimizer_D_PP.zero_grad()
                self.backward_D_PP()
                self.optimizer_D_PP.step()

        # D_BP
        if self.opt.with_D_PB:
            for i in range(self.opt.DG_ratio):
                self.optimizer_D_PB.zero_grad()
                self.backward_D_PB()
                self.optimizer_D_PB.step()

其中的forward函数为:

    def forward(self):
        G_input = [self.input_P1,
                   torch.cat((self.input_BP1, self.input_BP2), 1)]
        self.fake_p2 = self.netG(G_input)

总结一下就是分为以下几步:

3.1 前向计算生成网络G得到生成图片self.fake_p2

3.2 给G网络调参,即向后传播

backward_G()代码为:

    def backward_G(self):
        if self.opt.with_D_PB:
            pred_fake_PB = self.netD_PB(torch.cat((self.fake_p2, self.input_BP2), 1))
            self.loss_G_GAN_PB = self.criterionGAN(pred_fake_PB, True)

        if self.opt.with_D_PP:
            pred_fake_PP = self.netD_PP(torch.cat((self.fake_p2, self.input_P1), 1))
            self.loss_G_GAN_PP = self.criterionGAN(pred_fake_PP, True)

        # L1 loss
        if self.opt.L1_type == 'l1_plus_perL1' :
            losses = self.criterionL1(self.fake_p2, self.input_P2)
            self.loss_G_L1 = losses[0]
            self.loss_originL1 = losses[1].item()
            self.loss_perceptual = losses[2].item()
        else:
            self.loss_G_L1 = self.criterionL1(self.fake_p2, self.input_P2) * self.opt.lambda_A


        pair_L1loss = self.loss_G_L1
        if self.opt.with_D_PB:
            pair_GANloss = self.loss_G_GAN_PB * self.opt.lambda_GAN
            if self.opt.with_D_PP:
                pair_GANloss += self.loss_G_GAN_PP * self.opt.lambda_GAN
                pair_GANloss = pair_GANloss / 2
        else:
            if self.opt.with_D_PP:
                pair_GANloss = self.loss_G_GAN_PP * self.opt.lambda_GAN

        if self.opt.with_D_PB or self.opt.with_D_PP:
            pair_loss = pair_L1loss + pair_GANloss
        else:
            pair_loss = pair_L1loss

        pair_loss.backward()

        self.pair_L1loss = pair_L1loss.item()
        if self.opt.with_D_PB or self.opt.with_D_PP:
            self.pair_GANloss = pair_GANloss.item()

文字表述就是:
1.分别计算分类器生成的D_PP,D_PB(链上文2.2)的分类损失(生成目标为了混淆分类器,理想值应为True),分别记做loss_G_GAN_PB、loss_G_GAN_PP

2.(l1_plus_perL1)将目标图片与生成图片做l1_plus_perL1的损失函数计算。来看L1_plus_perceptualLoss的具体代码:
首先是该loss层的forward函数:

    def forward(self, inputs, targets):
        if self.lambda_L1 == 0 and self.lambda_perceptual == 0:
            return torch.zeros(1).cuda(), torch.zeros(1), torch.zeros(1)
        # normal L1
        loss_l1 = F.l1_loss(inputs, targets) * self.lambda_L1

        # perceptual L1
        mean = torch.FloatTensor(3)
        mean[0] = 0.485
        mean[1] = 0.456
        mean[2] = 0.406
        mean = mean.resize(1, 3, 1, 1).cuda()

        std = torch.FloatTensor(3)
        std[0] = 0.229
        std[1] = 0.224
        std[2] = 0.225
        std = std.resize(1, 3, 1, 1).cuda()

        fake_p2_norm = (inputs + 1)/2 # [-1, 1] => [0, 1]
        fake_p2_norm = (fake_p2_norm - mean)/std

        input_p2_norm = (targets + 1)/2 # [-1, 1] => [0, 1]
        input_p2_norm = (input_p2_norm - mean)/std


        fake_p2_norm = self.vgg_submodel(fake_p2_norm)
        input_p2_norm = self.vgg_submodel(input_p2_norm)
        input_p2_norm_no_grad = input_p2_norm.detach()

        if self.percep_is_l1 == 1:
            # use l1 for perceptual loss
            loss_perceptual = F.l1_loss(fake_p2_norm, input_p2_norm_no_grad) * self.lambda_perceptual
        else:
            # use l2 for perceptual loss
            loss_perceptual = F.mse_loss(fake_p2_norm, input_p2_norm_no_grad) * self.lambda_perceptual

        loss = loss_l1 + loss_perceptual

        return loss, loss_l1, loss_perceptual

l1_plus_perL1包括两种loss:一种是普通的L1 loss,即直接将input和target做L1 loss,记做loss_l1。
另一种 loss_perceptual的计算过程如下:(1)对input和target分别做normalize,其实就是将他们从[-1,1]的范围变到[0,1],然后减去mean再除以方差std(2)将normalize后的input和target送给vgg网络得到输出fake_p2_norm,和input_p2_norm_no_grad(3)将两个输出做L1 loss得到loss_perceptual
loss_perceptual是为了让图片更加平滑和自然,引入论文原文:


L1_percetual.png

将两种loss相加就得到了最后的loss。
回到backward_G(),三种loss分别记为loss,loss_originL1,loss_perceptual
3.计算总loss
链接原文公式:


full_loss.png

losscombl1.png

在上面的代码中,4式中的α为2,也就是Lcomb除以2之后加上Lgan为总loss。
最后调用总loss.backward()跟新参数

3.3给两个D网络调参(链上文2.2节)

在上面更新了一次G网络之后,更新DG_ratio次分类网络D_PP和D_PB

3.3.1 给D_PP网络调参

        if self.opt.with_D_PP:
            for i in range(self.opt.DG_ratio):
                self.optimizer_D_PP.zero_grad()
                self.backward_D_PP()
                self.optimizer_D_PP.step()
    def backward_D_PP(self):
        real_PP = torch.cat((self.input_P2, self.input_P1), 1)
        # fake_PP = self.fake_PP_pool.query(torch.cat((self.fake_p2, self.input_P1), 1))
        fake_PP = self.fake_PP_pool.query( torch.cat((self.fake_p2, self.input_P1), 1).data )
        loss_D_PP = self.backward_D_basic(self.netD_PP, real_PP, fake_PP)
        self.loss_D_PP = loss_D_PP.item()
   def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True) * self.opt.lambda_GAN
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False) * self.opt.lambda_GAN
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

总结步骤如下:
1、将生成图片fake_p2和输入原图片input_P1传给fake_PP_pool.query函数,这个query函数的代码如下:

    def query(self, images):
        if self.pool_size == 0:
            return Variable(images)
        return_images = []
        for image in images:
            image = torch.unsqueeze(image, 0)
            if self.num_imgs < self.pool_size:
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                if p > 0.5:
                    random_id = random.randint(0, self.pool_size-1)
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:
                    return_images.append(image)
        return_images = Variable(torch.cat(return_images, 0))
        return return_images

这在干啥咱也不知道咱也不敢问,根据默认的配置的话跑的话是不超过50张图时将fake_p2和input_p1拼接起来返回,超过了就是取之前训练的前49张的某张图片来跟当前图片交换。
2、将输入原图片和目标图片拼接起来得到real_PP,拿D_PP网络去预测real_PP,计算预测结果与理想结果(TRUE)之间的loss,记为loss_D_real。拿D_PP网络去预测fake_PP,计算预测结果与理想结果(FALSE)之间的loss,记为loss_D_fake。这里提醒一下D_PP网络用于预测两张图是否包含同一个人。
3.总loss loss_D= (loss_D_real + loss_D_fake) * 0.5,loss_D.backward()更新参数

3.3.2给D_PB网络调参

 def backward_D_PB(self):
        real_PB = torch.cat((self.input_P2, self.input_BP2), 1)
        # fake_PB = self.fake_PB_pool.query(torch.cat((self.fake_p2, self.input_BP2), 1))
        fake_PB = self.fake_PB_pool.query( torch.cat((self.fake_p2, self.input_BP2), 1).data )
        loss_D_PB = self.backward_D_basic(self.netD_PB, real_PB, fake_PB)
        self.loss_D_PB = loss_D_PB.item()

跟D_PP类似,只不过real_PB拼接的是目标图片和目标姿势,fake_PB拼接的是生成图片和目标姿势。提醒一下D_PB用于判断图中的人的姿势是否为目标姿势。

四、总结

本文主要讲了训练逻辑,笔者觉得弄懂了训练代码,看测试代码就简单多了,就不再在文里分析了。

你可能感兴趣的:(Pose-Transfer代码阅读笔记)