SEAN代码(1)

代码地址
首先定义一个trainer。

trainer = Pix2PixTrainer(opt)

在Pix2PixTrainer内部,首先定义Pix2PixModel模型。

self.pix2pix_model = Pix2PixModel(opt)

在Pix2PixModel内部定义生成器,判别器。

self.netG, self.netD, self.netE = self.initialize_networks(opt)

在initialize_networks内部定义功能。

netG = networks.define_G(opt)
netD = networks.define_D(opt) if opt.isTrain else None
netE = networks.define_E(opt) if opt.use_vae else None

首先看生成器:

def define_G(opt):
    netG_cls = find_network_using_name(opt.netG, 'generator')#netG=spade
    return create_network(netG_cls, opt)

输入的参数是opt.netG,在option中对应的是spade。在find_network_using_name中:

def find_network_using_name(target_network_name, filename):#spade,generator
    target_class_name = target_network_name + filename#spadegenerator
    module_name = 'models.networks.' + filename#models.networks.generator
    network = util.find_class_in_module(target_class_name, module_name)#
    assert issubclass(network, BaseNetwork), \
        "Class %s should be a subclass of BaseNetwork" % network

    return network

根据target_network_name和对应的filename输入到find_class_in_module中:

def find_class_in_module(target_cls_name, module):
    target_cls_name = target_cls_name.replace('_', '').lower()#spadegenerator
    clslib = importlib.import_module(module)#import_module()返回指定的包或模块
    cls = None
    for name, clsobj in clslib.__dict__.items():
        if name.lower() == target_cls_name:
            cls = clsobj

    if cls is None:
        print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name))
        exit(0)

    return cls

我们通过import_module函数载入module这个模块,module对应的是models.networks.generator。即clslib 就是generator文件中的类。我们遍历clslib的字典,如果name等于spadegenerator,令cls = clsobj。
即network等于cls。

network = util.find_class_in_module(target_class_name, module_name)

这里有两个语法问题:
①:导入importlib,调用import_module()方法,根据输入的字符串可以获得模块clslib ,clslib 可以调用models.networks.generator文件下所有的属性和方法。
SEAN代码(1)_第1张图片
在generator内部是:
SEAN代码(1)_第2张图片可以通过clslib.SPADEGenerator来实例化SPADEGenerator,然后再调用SPADEGenerator内部的方法。
举个例子:新建三个文件。
SEAN代码(1)_第3张图片
train:
SEAN代码(1)_第4张图片
用不到test,在tt文件内部中导入train中的类s。
SEAN代码(1)_第5张图片
因为是同级目录,直接导入字符串train即可,如果不在同级目录,需要导入前一个目录。
接着a就会变成一个module,即train。然后实例化train文件夹下的类s。最后调用类s的方法kill和qqq。
输出:
在这里插入图片描述
②: dict,该属性可以用类名或者类的实例对象来调用,用**类名直接调用 dict,会输出该由类中所有类属性组成的字典;**而使用类的实例对象调用 dict,会输出由类中所有实例属性组成的字典。
参考
这里SPADEGenerator继承了BaseNetwork,对于具有继承关系的父类和子类来说,父类有自己的 dict,同样子类也有自己的 dict,它不会包含父类的 dict
例子:按上面的例子,a是一个module,查看a的__dict__:
在这里插入图片描述
输出:
SEAN代码(1)_第6张图片
回到代码中:我们输出的network就是类
下一步我们创建网络:在这里插入图片描述
SEAN代码(1)_第7张图片
cls对应的是SPADEGenerator网络。
在SPADE中:

"""
Copyright (C) 2019 NVIDIA Corporation.  All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from models.networks.base_network import BaseNetwork
from models.networks.normalization import get_nonspade_norm_layer
from models.networks.architecture import ResnetBlock as ResnetBlock
from models.networks.architecture import SPADEResnetBlock as SPADEResnetBlock
from models.networks.architecture import Zencoder

class SPADEGenerator(BaseNetwork):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.set_defaults(norm_G='spectralspadesyncbatch3x3')
        parser.add_argument('--num_upsampling_layers',
                            choices=('normal', 'more', 'most'), default='normal',
                            help="If 'more', adds upsampling layer between the two middle resnet blocks. If 'most', also add one more upsampling + resnet layer at the end of the generator")

        return parser

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh = self.compute_latent_vector_size(opt)

        self.Zencoder = Zencoder(3, 512)


        self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)

        self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, Block_Name='head_0')

        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, Block_Name='G_middle_0')
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt, Block_Name='G_middle_1')

        self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt, Block_Name='up_0')
        self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt, Block_Name='up_1')
        self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt, Block_Name='up_2')
        self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt, Block_Name='up_3', use_rgb=False)

        final_nc = nf

        if opt.num_upsampling_layers == 'most':
            self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt, Block_Name='up_4')
            final_nc = nf // 2

        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)
        #self.up = nn.Upsample(scale_factor=2, mode='bilinear')
    def compute_latent_vector_size(self, opt):
        if opt.num_upsampling_layers == 'normal':#默认
            num_up_layers = 5
        elif opt.num_upsampling_layers == 'more':
            num_up_layers = 6
        elif opt.num_upsampling_layers == 'most':
            num_up_layers = 7
        else:
            raise ValueError('opt.num_upsampling_layers [%s] not recognized' %
                             opt.num_upsampling_layers)

        sw = opt.crop_size // (2**num_up_layers)#256//32=16
        sh = round(sw / opt.aspect_ratio)#8

        return sw, sh

    def forward(self, input, rgb_img, obj_dic=None):
        seg = input
        x = F.interpolate(seg, size=(self.sh, self.sw))#(16,16)
        x = self.fc(x)#(b,1024,16,16)

        style_codes = self.Zencoder(input=rgb_img, segmap=seg)
        x = self.head_0(x, seg, style_codes, obj_dic=obj_dic)

        x = self.up(x)
        x = self.G_middle_0(x, seg, style_codes, obj_dic=obj_dic)

        if self.opt.num_upsampling_layers == 'more' or \
           self.opt.num_upsampling_layers == 'most':
            x = self.up(x)

        x = self.G_middle_1(x, seg, style_codes,  obj_dic=obj_dic)

        x = self.up(x)
        x = self.up_0(x, seg, style_codes, obj_dic=obj_dic)
        x = self.up(x)
        x = self.up_1(x, seg, style_codes, obj_dic=obj_dic)
        x = self.up(x)
        x = self.up_2(x, seg, style_codes, obj_dic=obj_dic)
        x = self.up(x)
        x = self.up_3(x, seg, style_codes,  obj_dic=obj_dic)

        # if self.opt.num_upsampling_layers == 'most':
        #     x = self.up(x)
        #     x= self.up_4(x, seg, style_codes,  obj_dic=obj_dic)

        x = self.conv_img(F.leaky_relu(x, 2e-1))
        x = F.tanh(x)
        return x

首先计算潜在空间向量的大小:
SEAN代码(1)_第8张图片
接着计算style matrixST。对应文章的 :
SEAN代码(1)_第9张图片
在代码中:通过卷积,下采样,下采样,上采样,卷积。输出一个通道为512的向量。
SEAN代码(1)_第10张图片
接着是连续的四个上采样模块:
在这里插入图片描述
对应于:
SEAN代码(1)_第11张图片
在SPADEResnetBlock内部:使用ACE类定义了SEAN块。
SEAN代码(1)_第12张图片
在ACE内部定义了归一化的参数和噪声等。
SEAN代码(1)_第13张图片
下面设计python正则表达式,没学过,下去补。只能先用debug获得结果。
在这里插入图片描述
这里使用SynchronizedBatchNorm2d进行归一化:
在这里插入图片描述
γ和β通过卷积获得:
SEAN代码(1)_第14张图片
执行完上采样的四个SEAN块之后,最后进过一个卷积输出合成图像。这就是整个network的流程。
生成器打印参数:
在这里插入图片描述
接着是判别器:
按照生成器的逻辑,target_class_name=multiscalediscriminator,module_name=models.networks.discriminator
然后我们导入判别器模块。
SEAN代码(1)_第15张图片
在多尺度判别器内部:创建两个single_discriminator。
SEAN代码(1)_第16张图片
SEAN代码(1)_第17张图片
在单个判别器内部定义参数:SEAN代码(1)_第18张图片
定义判别器的输入:将label通道和RGB图片拼接后输入。
SEAN代码(1)_第19张图片
接着经过一个4x4大小步长为2的卷积,再经过两个步长为2的卷积,最后再经过输出通道为1,步长为1的卷积。将每一个卷积都注册到模型中。
在这里插入图片描述
即判别器由五个卷积组成。
将单个判别器注册到判别器中。注册两次,这样盘比起由10个卷积组成,且都有对应的吗名称。
在这里插入图片描述

MultiscaleDiscriminator(
  (discriminator_0): NLayerDiscriminator(
    (model0): Sequential(
      (0): Conv2d(16, 64, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
      (1): LeakyReLU(negative_slope=0.2)
    )
    (model1): Sequential(
      (0): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2), bias=False)
        (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      )
      (1): LeakyReLU(negative_slope=0.2)
    )
    (model2): Sequential(
      (0): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2), bias=False)
        (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      )
      (1): LeakyReLU(negative_slope=0.2)
    )
    (model3): Sequential(
      (0): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
    )
  )
  (discriminator_1): NLayerDiscriminator(
    (model0): Sequential(
      (0): Conv2d(16, 64, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
      (1): LeakyReLU(negative_slope=0.2)
    )
    (model1): Sequential(
      (0): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2), bias=False)
        (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      )
      (1): LeakyReLU(negative_slope=0.2)
    )
    (model2): Sequential(
      (0): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2), bias=False)
        (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      )
      (1): LeakyReLU(negative_slope=0.2)
    )
    (model3): Sequential(
      (0): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
    )
  )
)

这样生成器判别器都构造完毕,netE为空。

你可能感兴趣的:(paper代码,生成对抗网络,人工智能,神经网络,pytorch)