2019-CVPR-Geometry-Consistent Generative Adversarial Networks for One-Sided Unsupervised Domain Mapp

单侧无监督域适应的几何一致生成对抗网络

1.摘要:

无监督域映射旨在学习一个函数G_{xy}翻译域X图像到域Y图像,在配对样本缺少的情况下。在没有配对数据情况下,发现最优的G_{xy}是一个病态的问题,因此获得合理的解需要适合的约束。尽管一些著名的(prominent)约束,例如循环一致性(cycle consistency), 距离保留(distance preservation)成功地约束解空间,但是他们忽视了图像的特殊属性——简单的几何变换不会改变图像的语义结构。基于这些特殊的属性,作者提出一个几何一致性生成对抗网络(geometry-consistent generative adversarial networks -GcGAN), 进行单侧无监督域适应。GcGAN把原始图像以及对应几何变换图像作为模型的输入,并在新域中生成两个图像,并加上相应的几何一致性约束。几何一致性约束减少了可能解的空间并保证了在搜索空间中正确解。与基线模型GAN,以及最新的方法CycleGAN, DistanceGAN进行定性,定量的对比证明我们的方法的有效性。

 

2019-CVPR-Geometry-Consistent Generative Adversarial Networks for One-Sided Unsupervised Domain Mapp_第1张图片

该图片可视化CycleGAN, DistanceGAN , GcGAN之间的不同。

文章中,作者采用了两种常用的几何变换:顺时针旋转90度,垂直翻转。

2.结合代码讲论文:

论文中实验最优结果所使用的约束:对抗约束,几何一致性约束,循环一致性约束(GcGAN-rot + Cycle)

对抗约束:

2019-CVPR-Geometry-Consistent Generative Adversarial Networks for One-Sided Unsupervised Domain Mapp_第2张图片


循环一致性约束:

2019-CVPR-Geometry-Consistent Generative Adversarial Networks for One-Sided Unsupervised Domain Mapp_第3张图片


几何一致性约束:

2019-CVPR-Geometry-Consistent Generative Adversarial Networks for One-Sided Unsupervised Domain Mapp_第4张图片


距离约束 Distance constraint:

2019-CVPR-Geometry-Consistent Generative Adversarial Networks for One-Sided Unsupervised Domain Mapp_第5张图片

 


代码分析:

对代码链接中gc_cycle_gan_model.py文件中的代码进行分析:

模型的主要优化过程:输入数据,优化生成器,优化目标域判别器,优化源域判别器。

  def optimize_parameters(self):
        # forward
        self.forward()
        # G_AB
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_B and D_gc_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()

输入数据


    def forward(self):
        input_A = self.input_A.clone()
        input_B = self.input_B.clone()

        self.real_A = self.input_A
        self.real_B = self.input_B

        size = self.opt.fineSize

        if self.opt.geometry == 'rot':
          self.real_gc_A = self.rot90(input_A, 0)
          self.real_gc_B = self.rot90(input_B, 0)
        elif self.opt.geometry == 'vf':
          inv_idx = torch.arange(size-1, -1, -1).long().cuda()
          self.real_gc_A = torch.index_select(input_A, 2, inv_idx)
          self.real_gc_B = torch.index_select(input_B, 2, inv_idx)
        else:
          raise ValueError("Geometry transformation function [%s] not recognized." % self.opt.geometry)

输入源域图像self.input_A,目标域图像self.input_B,根据事先设置的几何变换参数self.opt.gemetry进行顺时针旋转90度变换(在代码中体现,self.real_gc_A = self.rot90(input_A, 0),如果是逆时针旋转90度,则self.rot()函数第二个参数设置为1)。


优化生成器:


    def backward_G(self):
        # adversariasl loss
        fake_B = self.netG_AB.forward(self.real_A)
        pred_fake = self.netD_B.forward(fake_B)
        loss_G_AB = self.criterionGAN(pred_fake, True)*self.opt.lambda_G

        fake_gc_B = self.netG_AB.forward(self.real_gc_A)
        pred_fake = self.netD_gc_B.forward(fake_gc_B)
        loss_G_gc_AB = self.criterionGAN(pred_fake, True)*self.opt.lambda_G

        fake_A = self.netG_BA.forward(self.real_B)
        pred_fake = self.netD_A.forward(fake_A)
        loss_G_AB += self.criterionGAN(pred_fake, True)*self.opt.lambda_G

        fake_gc_A = self.netG_BA.forward(self.real_gc_B)
        pred_fake = self.netD_gc_A.forward(fake_gc_A)
        loss_G_gc_AB += self.criterionGAN(pred_fake, True)*self.opt.lambda_G

        if self.opt.geometry == 'rot':
            loss_gc = self.get_gc_rot_loss(fake_B, fake_gc_B, 0)
            loss_gc += self.get_gc_rot_loss(fake_A, fake_gc_A, 0)
        elif self.opt.geometry == 'vf':
            loss_gc = self.get_gc_vf_loss(fake_B, fake_gc_B)
            loss_gc += self.get_gc_vf_loss(fake_A, fake_gc_A)

        if self.opt.identity > 0:
            # G_AB should be identity if real_B is fed.
            idt_A = self.netG_AB(self.real_B)
            loss_idt = self.criterionIdt(idt_A, self.real_B) * self.opt.lambda_AB * self.opt.identity
            idt_gc_A = self.netG_AB(self.real_gc_B)
            loss_idt_gc = self.criterionIdt(idt_gc_A, self.real_gc_B) * self.opt.lambda_AB * self.opt.identity

            idt_B = self.netG_BA(self.real_A)
            loss_idt += self.criterionIdt(idt_B, self.real_A) * self.opt.lambda_AB * self.opt.identity
            idt_gc_B = self.netG_BA(self.real_gc_A)
            loss_idt_gc += self.criterionIdt(idt_gc_B, self.real_gc_A) * self.opt.lambda_AB * self.opt.identity

            self.idt_A = idt_A.data
            self.idt_gc_A = idt_gc_A.data
            self.loss_idt = loss_idt.item()
            self.loss_idt_gc = loss_idt_gc.item()
        else:
            loss_idt = 0
            loss_idt_gc = 0
            self.loss_idt = 0
            self.loss_idt_gc = 0

        rec_A = self.netG_BA(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * self.opt.lambda_AB
        rec_gc_A = self.netG_BA(fake_gc_B)
        loss_cycle_A += self.criterionCycle(rec_gc_A, self.real_gc_A) * self.opt.lambda_AB
        
        rec_B = self.netG_AB(fake_A)
        loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * self.opt.lambda_AB
        rec_gc_B = self.netG_BA(fake_gc_A)
        loss_cycle_B += self.criterionCycle(rec_gc_B, self.real_gc_B) * self.opt.lambda_AB

        loss_G = loss_G_AB + loss_G_gc_AB + loss_gc + loss_idt + loss_idt_gc + loss_cycle_A + loss_cycle_B

        loss_G.backward()

        self.fake_B = fake_B.data
        self.fake_gc_B = fake_gc_B.data
        self.fake_A = fake_A.data
        self.fake_gc_A = fake_gc_A.data

        self.loss_G_AB = loss_G_AB.item()
        self.loss_G_gc_AB= loss_G_gc_AB.item()
        self.loss_gc = loss_gc.item()

2019-CVPR-Geometry-Consistent Generative Adversarial Networks for One-Sided Unsupervised Domain Mapp_第6张图片

代码中self.real_A,fake_B,  self.real_gc_A, fake_gc_B分别对应图中x_{i},y_{i}^{'}, \tilde{x_{i}},\tilde{y}_{i}^{'}

self.netG_AB 对应G_{XY}, G_{\tilde{X}\tilde{Y}}

2019-CVPR-Geometry-Consistent Generative Adversarial Networks for One-Sided Unsupervised Domain Mapp_第7张图片

几何一致性损失约束,体现在:

      if self.opt.geometry == 'rot':
            loss_gc = self.get_gc_rot_loss(fake_B, fake_gc_B, 0)
            loss_gc += self.get_gc_rot_loss(fake_A, fake_gc_A, 0)
        elif self.opt.geometry == 'vf':
            loss_gc = self.get_gc_vf_loss(fake_B, fake_gc_B)
            loss_gc += self.get_gc_vf_loss(fake_A, fake_gc_A)
    def get_gc_rot_loss(self, AB, AB_gc, direction):
        loss_gc = 0.0

        if direction == 0:
          AB_gt = self.rot90(AB_gc.clone().detach(), 1)
          loss_gc = self.criterionGc(AB, AB_gt)
          AB_gc_gt = self.rot90(AB.clone().detach(), 0)
          loss_gc += self.criterionGc(AB_gc, AB_gc_gt)
        else:
          AB_gt = self.rot90(AB_gc.clone().detach(), 0)
          loss_gc = self.criterionGc(AB, AB_gt)
          AB_gc_gt = self.rot90(AB.clone().detach(), 1)
          loss_gc += self.criterionGc(AB_gc, AB_gc_gt)

        loss_gc = loss_gc*self.opt.lambda_AB*self.opt.lambda_gc
        #loss_gc = loss_gc*self.opt.lambda_AB
        return loss_gc

y_{i}^{'}f^{-1}(\tilde{y}_{i}^{'})进行L1范数计算流程:

在get_gc_rot_loss()损失中:

y_{i}^{'}, \tilde{y}_{i}^{'},f^{-1}(\tilde{y}_{i}^{'}),f(y_{i}^{'}),分别对应与AB,AB_gc, AB_gt,AB_gc_gt。


作者使用几何一致性损失 + 旋转几何变换 + 循环一致性损失在城市景观数据集上去的最好的结果。要搞清楚作者在合成图像重建回原始图像过程中是否使用几何一致性损失?

根据代码以及论文的题目中“单侧无监督域适应”,所以作者只在x_{i}翻译到y_{i}^{'}或者y_{j}翻译到x_{j}^{'}过程中使用了几何一致性损失。

2019-CVPR-Geometry-Consistent Generative Adversarial Networks for One-Sided Unsupervised Domain Mapp_第8张图片

代码链接:

【1】https://github.com/hufu6371/GcGAN

 

作者怎么评估翻译图像的质量?

对于图像标签图翻译到图像的过程,作者认为高质量的翻译图像应该产生定性的分割结果,就像真实图像的分割结果一样。

因此作者使用pixel accuracy, class accuracy, mean IOU 评估翻译图像的分割结果,使用pix2pix 提供的预训练模型FCN-8s分割合成图像。

 

你可能感兴趣的:(域适应)