CycleGAN(四)inference过程与model定义

背景:我们需要搞懂cycleGAN如何对已有图片进行inference

目录

一、嵌套位置

1.1 调用位置

1.2 inference调用的函数

二、前馈运算

2.1 forward

2.2 实验结果及解释

三、模型

3.1 模型定义

3.2 定义loss

3.3 模型结构


一、嵌套位置

1.1 调用位置

CycleGAN(四)inference过程与model定义_第1张图片

test.py之中,很容易看到调用inference的部分

    for i, data in enumerate(dataset):
        if i >= opt.num_test:  # only apply our model to opt.num_test images.
            break
        model.set_input(data)  # unpack data from data loader
        model.test()           # run inference
        visuals = model.get_current_visuals()  # get image results
        img_path = model.get_image_paths()     # get image paths

1.2 inference调用的函数

CycleGAN(四)inference过程与model定义_第2张图片

二、前馈运算

2.1 forward

    def forward(self):
        """Run forward pass; called by both functions  and ."""
        self.fake_B = self.netG_A(self.real_A)  # G_A(A)
        self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
        self.fake_A = self.netG_B(self.real_B)  # G_B(B)
        self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))

 CycleGAN(四)inference过程与model定义_第3张图片

  • real_A输入netG_A 生成 fake_B
  • fake_B输入netG_B生成fake_A
  • real_B输入netG_B生成fake_A
  • fake_A输入netG_A生成rec_B
  • 即fake就是根据real通过生成器G生成的
  • rec就是re-cycle,就是A通过两次生成器返回的A

2.2 实验结果及解释

这也给实验结果一定的解释:

  • 训练时模型是从trainA是正常布料,trainB是棉布料,测试时testA是正常,testB是撕裂布料
  • 即训练: 正常——棉   ,模型是从正常布料到棉布料的迁移
  • 测试:正常——破裂

CycleGAN(四)inference过程与model定义_第4张图片  CycleGAN(四)inference过程与model定义_第5张图片  CycleGAN(四)inference过程与model定义_第6张图片

分别是realA,fakeA,recA

CycleGAN(四)inference过程与model定义_第7张图片  CycleGAN(四)inference过程与model定义_第8张图片  CycleGAN(四)inference过程与model定义_第9张图片

分别是realB,fakeB,recB

三、模型

3.1 模型定义

base_model.py与cycle_gan_model.py之中定义了模型,loss等各种信息。

CycleGAN(四)inference过程与model定义_第10张图片

3.2 定义loss

具体参见:

CycleGAN(五)loss理解及更改与实验

初始化时定义了几种loss的名称,后面定义了backward_D_basic和backward_G

CycleGAN(四)inference过程与model定义_第11张图片

3.3 模型结构

CycleGAN(四)inference过程与model定义_第12张图片

    def forward(self):
        """Run forward pass; called by both functions  and ."""
        self.fake_B = self.netG_A(self.real_A)  # G_A(A)
        self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
        self.fake_A = self.netG_B(self.real_B)  # G_B(B)
        self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))

CycleGAN(四)inference过程与model定义_第13张图片

你可能感兴趣的:(机器学习,PyTorch,python,image2image,GAN)