DR-GAN implementation notes

My implementation is mainly a modification of this GitHub code. In fact, I only rewrote the dataset part. The code uses PyTorch 1.0.

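The code in the link expects several .npy files. The author does not provide them and does not include any code to generate them either. So I rolled up my sleeves, wrote that part myself, and then hooked it up to the author's code.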

For a walkthrough of the DR-GAN paper, see here.

Data loading

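import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class FaceIdPoseDataset(Dataset):

    # root: path to a txt file that lists one image path per line;
    # each line is prefixed with './dataset' to form the full image path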
    def __init__(self, root, transform=None):

        with open(root,'r') as f:
            data = ['./dataset'+ line.strip() for line in f.readlines()]
        self.data = np.random.permutation(data)
        self.transform = transform
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):

        sample = self.data[idx]
        img = Image.open(sample)
        img = img.convert('RGB')
        if sample.find('frontal') !=-1:
            pose_label = 1
        else :
            pose_label = 0
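        # Identity label taken from fixed character positions of the path. This assumes
        # list lines like '/Data/Images/001/frontal/01.jpg', so that
        # './dataset/Data/Images/001/...'[22:25] == '001'. CFP folders run 001-500,
        # so Nd must exceed the largest label (or subtract 1 for 0-based labels).
        id_label = int(sample[22:25])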
        img= self.transform(img)
        return [img.float(), id_label, pose_label]

def get_batch(root,batch_size):
    data_set = FaceIdPoseDataset(root,
                                 transform=transforms.Compose([
                                     transforms.Resize((110,110)),
                                     transforms.RandomCrop((96,96)),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
                                 ]))
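    # drop_last=True is necessary: the real/fake target tensors in the training loop
    # are created with a fixed conf.batch_size, so a smaller last batch would not match them
    dataloader = DataLoader(data_set, batch_size=batch_size,
                            shuffle=True, drop_last=True)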
    return dataloader

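The approach is as follows. The custom dataset reads a txt file that lists the file names of the CFP dataset images. It stores the names in a list and shuffles them.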
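The dataset only needs that txt list, so the list is all you have to generate. The following is a hedged sketch of one way to produce it. It walks the CFP image folders and writes one path per line, relative to the directory that './dataset' points to, so that FaceIdPoseDataset's `'./dataset' + line` yields a valid path. The helper name `write_file_list` and the `Data/Images/<id>/<frontal|profile>/` layout are assumptions, not part of the original code.

```python
import os


def write_file_list(image_root, list_path, prefix_root):
    """Hypothetical helper: write one image path per line, relative to prefix_root.

    Each line starts with '/', because FaceIdPoseDataset prepends './dataset'
    (assumed to point at prefix_root) to every line it reads.
    """
    lines = []
    for dirpath, _, filenames in os.walk(image_root):
        for name in filenames:
            if name.lower().endswith(('.jpg', '.jpeg', '.png')):
                full = os.path.join(dirpath, name)
                rel = os.path.relpath(full, prefix_root).replace(os.sep, '/')
                lines.append('/' + rel)
    lines.sort()  # deterministic file; the dataset shuffles anyway
    with open(list_path, 'w') as f:
        f.write(''.join(line + '\n' for line in lines))
    return len(lines)
```

For example, `write_file_list('dataset/Data', 'cfp_list.txt', 'dataset')` would produce lines such as `/Data/Images/001/frontal/01.jpg`.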

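Note also that the FaceIdPoseDataset constructor takes a transform argument, and this argument is required because __getitem__ calls it unconditionally. In __getitem__, simply read the sample at the given index. Open the image with PIL.Image.open, because the torchvision transforms used here, such as Resize and RandomCrop, operate on PIL images. Be sure to convert the image to RGB as well (this is also required). Without the conversion, a grayscale image would yield a single-channel tensor that matches neither the 3-channel Normalize nor the 3-channel network input. Then apply the transform for data augmentation.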
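In __getitem__, both labels come from the file path. The pose label depends on whether the path contains 'frontal'. The identity label comes from the fixed slice `sample[22:25]`, which silently breaks if the path prefix changes. The following hedged alternative extracts both labels with a regular expression instead. The helper name `parse_labels` and the `.../<3-digit id>/<frontal|profile>/...` layout are assumptions.

```python
import re

# Matches '/<3-digit id>/<frontal|profile>/' anywhere in the path (assumed CFP layout).
_LABEL_RE = re.compile(r'/(\d{3})/(frontal|profile)/')


def parse_labels(path, zero_based=True):
    """Hypothetical helper: return (id_label, pose_label) parsed from an image path.

    pose_label follows the dataset code above: 1 for frontal, 0 otherwise.
    With zero_based=True, folder '001' maps to id 0, which suits CrossEntropyLoss.
    """
    m = _LABEL_RE.search(path.replace('\\', '/'))
    if m is None:
        raise ValueError('unexpected path layout: %s' % path)
    id_label = int(m.group(1)) - (1 if zero_based else 0)
    pose_label = 1 if m.group(2) == 'frontal' else 0
    return id_label, pose_label
```

Inside __getitem__, `id_label, pose_label = parse_labels(sample)` would replace both the slice and the `find('frontal')` check.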

 

Model construction

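I only experimented with single-image DR-GAN. The code matches the architecture described in the paper and is easy to follow. Let's start with the discriminator.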

import torch.nn as nn

class Discriminator(nn.Module):
    """
    multi-task CNN for identity, real/fake and pose classification

    ### init
    Nd : Number of identities to classify
    Np : Number of poses to classify
    channel_num : Number of input image channels

    """

    def __init__(self, Nd, Np, channel_num):
        super(Discriminator, self).__init__()
        convLayers = [
            nn.Conv2d(channel_num, 32, 3, 1, 1, bias=False), # Bxchx96x96 -> Bx32x96x96
            nn.BatchNorm2d(32),
            nn.ELU(),
            nn.Conv2d(32, 64, 3, 1, 1, bias=False), # Bx32x96x96 -> Bx64x96x96
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.ZeroPad2d((0, 1, 0, 1)),                      # Bx64x96x96 -> Bx64x97x97
            nn.Conv2d(64, 64, 3, 2, 0, bias=False), # Bx64x97x97 -> Bx64x48x48
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.Conv2d(64, 64, 3, 1, 1, bias=False), # Bx64x48x48 -> Bx64x48x48
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.Conv2d(64, 128, 3, 1, 1, bias=False), # Bx64x48x48 -> Bx128x48x48
            nn.BatchNorm2d(128),
            nn.ELU(),
            nn.ZeroPad2d((0, 1, 0, 1)),                      # Bx128x48x48 -> Bx128x49x49
            nn.Conv2d(128, 128, 3, 2, 0, bias=False), #  Bx128x49x49 -> Bx128x24x24
            nn.BatchNorm2d(128),
            nn.ELU(),
            nn.Conv2d(128, 96, 3, 1, 1, bias=False), #  Bx128x24x24 -> Bx96x24x24
            nn.BatchNorm2d(96),
            nn.ELU(),
            nn.Conv2d(96, 192, 3, 1, 1, bias=False), #  Bx96x24x24 -> Bx192x24x24
            nn.BatchNorm2d(192),
            nn.ELU(),
            nn.ZeroPad2d((0, 1, 0, 1)),                      # Bx192x24x24 -> Bx192x25x25
            nn.Conv2d(192, 192, 3, 2, 0, bias=False), # Bx192x25x25 -> Bx192x12x12
            nn.BatchNorm2d(192),
            nn.ELU(),
            nn.Conv2d(192, 128, 3, 1, 1, bias=False), # Bx192x12x12 -> Bx128x12x12
            nn.BatchNorm2d(128),
            nn.ELU(),
            nn.Conv2d(128, 256, 3, 1, 1, bias=False), # Bx128x12x12 -> Bx256x12x12
            nn.BatchNorm2d(256),
            nn.ELU(),
            nn.ZeroPad2d((0, 1, 0, 1)),                      # Bx256x12x12 -> Bx256x13x13
            nn.Conv2d(256, 256, 3, 2, 0, bias=False),  # Bx256x13x13 -> Bx256x6x6
            nn.BatchNorm2d(256),
            nn.ELU(),
            nn.Conv2d(256, 160, 3, 1, 1, bias=False), # Bx256x6x6 -> Bx160x6x6
            nn.BatchNorm2d(160),
            nn.ELU(),
            nn.Conv2d(160, 320, 3, 1, 1, bias=False), # Bx160x6x6 -> Bx320x6x6
            nn.BatchNorm2d(320),
            nn.ELU(),
            nn.AvgPool2d(6, stride=1), #  Bx320x6x6 -> Bx320x1x1
        ]

        self.convLayers = nn.Sequential(*convLayers)
        self.fc = nn.Linear(320, Nd+1+Np)

        # Initialize all weights from N(0, 0.02)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0, 0.02)

            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.02)

    def forward(self, input):
        # Convolutions followed by average pooling give a B x 320 x 1 x 1 output
        x = self.convLayers(input)

        x = x.view(-1, 320)

        # Fully connected layer
        x = self.fc(x) # Bx320 -> B x (Nd+1+Np)

        return x

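The constructor follows the usual design of a convolutional stack followed by an FC layer, and then it initializes the parameters. Every convolution is followed by batch normalization, so the convolution bias can be dropped. BN subtracts the per-channel batch mean, which cancels any constant bias, and BN has its own learnable shift.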
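The following is a quick numerical check of the bias argument. It uses NumPy as a stand-in for the PyTorch layers, and the names `batch_norm` and `feat` are illustrative. It covers training-mode BN statistics only. The learnable scale and shift of nn.BatchNorm2d are omitted because they are applied after the mean subtraction and do not change the conclusion.

```python
import numpy as np


def batch_norm(x, eps=1e-5):
    """Per-channel batch normalization over (N, H, W), without affine parameters.

    x has shape (N, C, H, W), like the conv outputs in the discriminator.
    """
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
feat = rng.normal(size=(8, 4, 5, 5))    # stand-in for a bias-free conv output
bias = rng.normal(size=(1, 4, 1, 1))    # a per-channel conv bias
same = np.allclose(batch_norm(feat), batch_norm(feat + bias))
print(same)  # True: the bias only shifts the mean, which BN subtracts
```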

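Note the ZeroPad2d layers in the convolutional part. What are they for? The generator's encoder is set up the same way as D. So why does the encoder need zero padding? The padding matches the encoder to the decoder. The decoder is a stack of transposed convolutions that mirrors the encoder. Because the encoder has the same structure as D, let's go straight to the decoder code.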

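import torch
import torch.nn as nn

class Crop(nn.Module):
    """
    Crop a B x C x H x W feature map.

    ### init
    crop_list : [top, bottom, left, right], pixels to drop on each side

    Minimal definition consistent with how Crop is used below; the original
    repository's own Crop class may differ in details.
    """

    def __init__(self, crop_list):
        super(Crop, self).__init__()
        self.crop_list = crop_list

    def forward(self, x):
        top, bottom, left, right = self.crop_list
        H, W = x.size(2), x.size(3)
        return x[:, :, top:H - bottom, left:W - right]

class Generator(nn.Module):
    """
    Encoder/Decoder conditional GAN conditioned with pose vector and noise vector

    ### init
    Np : Dimension of pose vector (corresponds to number of discrete pose classes of the data)
    Nz : Dimension of noise vector
    channel_num : Number of image channels

    """

    def __init__(self, Np, Nz, channel_num):
        super(Generator, self).__init__()
        self.features = []

        # Encoder construction omitted in this post: it has exactly the same layers as
        # Discriminator.convLayers (Bxchx96x96 -> Bx320x1x1). As a shortcut, this sketch
        # takes a fresh copy of those layers from a throwaway Discriminator
        # (Nd=1 is arbitrary; its fc layer is discarded).
        G_enc_convLayers = list(Discriminator(1, Np, channel_num).convLayers)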
        
        self.G_enc_convLayers = nn.Sequential(*G_enc_convLayers)

        G_dec_convLayers = [
            nn.ConvTranspose2d(320,160, 3,1,1, bias=False), # Bx320x6x6 -> Bx160x6x6
            nn.BatchNorm2d(160),
            nn.ELU(),
            nn.ConvTranspose2d(160, 256, 3,1,1, bias=False), # Bx160x6x6 -> Bx256x6x6
            nn.BatchNorm2d(256),
            nn.ELU(),
            nn.ConvTranspose2d(256, 256, 3,2,0, bias=False), # Bx256x6x6 -> Bx256x13x13
            nn.BatchNorm2d(256),
            nn.ELU(),
            Crop([0, 1, 0, 1]),
            nn.ConvTranspose2d(256, 128, 3,1,1, bias=False), # Bx256x12x12 -> Bx128x12x12
            nn.BatchNorm2d(128),
            nn.ELU(),
            nn.ConvTranspose2d(128, 192,  3,1,1, bias=False), # Bx128x12x12 -> Bx192x12x12
            nn.BatchNorm2d(192),
            nn.ELU(),
            nn.ConvTranspose2d(192, 192,  3,2,0, bias=False), # Bx192x12x12 -> Bx192x25x25
            nn.BatchNorm2d(192),
            nn.ELU(),
            Crop([0, 1, 0, 1]),
            nn.ConvTranspose2d(192, 96,  3,1,1, bias=False), # Bx192x24x24 -> Bx96x24x24
            nn.BatchNorm2d(96),
            nn.ELU(),
            nn.ConvTranspose2d(96, 128,  3,1,1, bias=False), # Bx96x24x24 -> Bx128x24x24
            nn.BatchNorm2d(128),
            nn.ELU(),
            nn.ConvTranspose2d(128, 128,  3,2,0, bias=False), # Bx128x24x24 -> Bx128x49x49
            nn.BatchNorm2d(128),
            nn.ELU(),
            Crop([0, 1, 0, 1]),
            nn.ConvTranspose2d(128, 64,  3,1,1, bias=False), # Bx128x48x48 -> Bx64x48x48
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.ConvTranspose2d(64, 64,  3,1,1, bias=False), # Bx64x48x48 -> Bx64x48x48
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.ConvTranspose2d(64, 64,  3,2,0, bias=False), # Bx64x48x48 -> Bx64x97x97
            nn.BatchNorm2d(64),
            nn.ELU(),
            Crop([0, 1, 0, 1]),
            nn.ConvTranspose2d(64, 32,  3,1,1, bias=False), # Bx64x96x96 -> Bx32x96x96
            nn.BatchNorm2d(32),
            nn.ELU(),
            nn.ConvTranspose2d(32, channel_num,  3,1,1, bias=False), # Bx32x96x96 -> Bxchx96x96
            nn.Tanh(),
        ]

        self.G_dec_convLayers = nn.Sequential(*G_dec_convLayers)

        self.G_dec_fc = nn.Linear(320+Np+Nz, 320*6*6)

        # Initialize all weights from N(0, 0.02)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0, 0.02)

            elif isinstance(m, nn.ConvTranspose2d):
                m.weight.data.normal_(0, 0.02)

            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.02)



    def forward(self, input, pose, noise):

        x = self.G_enc_convLayers(input) # Bxchx96x96 -> Bx320x1x1

        x = x.view(-1,320)

        self.features = x

        x = torch.cat([x, pose, noise], 1)  # Bx320 -> B x (320+Np+Nz)

        x = self.G_dec_fc(x) # B x (320+Np+Nz) -> B x (320x6x6)

        x = x.view(-1, 320, 6, 6) # B x (320x6x6) -> B x 320 x 6 x 6

        x = self.G_dec_convLayers(x) #  B x 320 x 6 x 6 -> Bxchx96x96

        return x

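The G_dec_convLayers list contains Crop layers. They trim the feature map by dropping its last row and last column. Why is this needed? Because of the output size that nn.ConvTranspose2d produces in its forward pass.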

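Consider what a transposed convolution (deconvolution) does here. It enlarges the feature map, usually by a factor of about two. With kernel 3, stride 2 and no padding, however, the output size is (n - 1) * 2 + 3 = 2n + 1. That size is always odd, so it is not an exact doubling.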

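Convolution cannot be inverted uniquely in the same way. Take the padding scheme used here: one zero pixel is added on the right and bottom, and then a 3x3 stride-2 convolution runs without padding. Under this scheme, an input of size 14 and an input of size 15 both produce a map of size 7. (With TensorFlow-style 'SAME' padding, 14 gives 7 but 15 gives 8.)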

So which size should the corresponding transposed convolution output, 14 or 15? It outputs 15. But what if the input to the corresponding convolution was 14?

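The decoder above is exactly this case. Every encoder input is even (96, 48, 24, 12), while each stride-2 transposed convolution outputs an odd size (97, 49, 25, 13). Crop therefore removes a one-pixel border (the last row and column) to restore the even size.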
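You can check this bookkeeping with the standard output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d` (dilation 1). The helper names in this sketch are made up. `tf_same_pad` reproduces TensorFlow's `padding='SAME'` rule. For even inputs with a 3x3 stride-2 kernel, that rule pads 0 pixels before and 1 after, which is what `ZeroPad2d((0, 1, 0, 1))` does.

```python
def conv_out(size, k, s, p=0):
    """Output size of nn.Conv2d along one axis (dilation 1)."""
    return (size + 2 * p - k) // s + 1


def conv_transpose_out(size, k, s, p=0):
    """Output size of nn.ConvTranspose2d along one axis (dilation 1, no output_padding)."""
    return (size - 1) * s - 2 * p + k


def tf_same_pad(size, k, s):
    """(before, after) padding that TensorFlow uses for padding='SAME'."""
    out = -(-size // s)  # ceil(size / s)
    total = max((out - 1) * s + k - size, 0)
    return total // 2, total - total // 2


# Downsampling step: ZeroPad2d((0, 1, 0, 1)), then Conv2d(k=3, s=2, p=0);
# upsampling step: ConvTranspose2d(k=3, s=2, p=0), then Crop([0, 1, 0, 1]).
for n in (96, 48, 24, 12):
    down = conv_out(n + 1, 3, 2)
    up = conv_transpose_out(down, 3, 2)
    print(n, '->', down, '->', up, '-> crop ->', up - 1)
# first line printed: 96 -> 48 -> 97 -> crop -> 96

# Inputs 14 and 15 (padded to 15 and 16) map to the same size, so the inverse is ambiguous.
print(conv_out(14 + 1, 3, 2), conv_out(15 + 1, 3, 2))  # 7 7
print(conv_transpose_out(7, 3, 2))                     # 15
print(tf_same_pad(96, 3, 2))                           # (0, 1)
```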

Loss functions

loss_criterion = nn.CrossEntropyLoss()
loss_criterion_gan = nn.BCEWithLogitsLoss()

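The first is the standard multi-class cross-entropy. The second is binary cross-entropy on raw logits, which applies the sigmoid internally.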
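For reference, here is a NumPy sketch of what the two criteria compute under the default mean reduction. The function names are illustrative, not PyTorch API. CrossEntropyLoss applies log-softmax to the raw scores and picks the target class. BCEWithLogitsLoss applies a sigmoid to a single raw logit inside the loss. This is why D's real/fake output needs no explicit sigmoid layer.

```python
import numpy as np


def cross_entropy(logits, target):
    """Mean of -log softmax(logits)[target], what nn.CrossEntropyLoss computes."""
    shifted = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(target)), target].mean()


def bce_with_logits(logit, label):
    """Mean of -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))], what nn.BCEWithLogitsLoss computes."""
    # -log(sigmoid(x)) = log(1 + e^-x); -log(1 - sigmoid(x)) = x + log(1 + e^-x)
    softplus_neg = np.logaddexp(0.0, -logit)               # log(1 + e^-x), computed stably
    return (label * softplus_neg + (1 - label) * (logit + softplus_neg)).mean()


print(cross_entropy(np.array([[2.0, 0.0, 0.0]]), np.array([0])))  # log(1 + 2e^-2), about 0.2395
print(bce_with_logits(np.array([0.0]), np.array([1.0])))           # log(2), about 0.6931
```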

batch_ones_label = t.ones(conf.batch_size)    # real/fake targets: 1 = real (t is torch)
batch_zeros_label = t.zeros(conf.batch_size)  # 0 = synthesized

fixed_noise = t.FloatTensor(
    np.random.uniform(-1, 1, (conf.batch_size, conf.nz)))
tmp = t.LongTensor(np.random.randint(conf.np, size=conf.batch_size))
pose_code = one_hot(tmp, conf.np)      # used to condition G
pose_code_label = t.LongTensor(tmp)    # used as the CrossEntropy target

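The first two lines (batch_ones_label and batch_zeros_label) are the real/fake targets: 1 for real samples and 0 for synthesized ones.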

fixed_noise is the noise vector z from the paper. It is sampled uniformly from [-1, 1] and has dimension conf.nz.

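tmp holds a random pose index for each sample and is only used to build the pose code. pose_code is the one-hot encoding of tmp, which means the discrete poses are treated as classes. pose_code_label keeps the integer index as the cross-entropy target.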
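The snippet calls a `one_hot` helper that the excerpt does not show. The following is a minimal NumPy sketch of the assumed behavior: row i gets a 1 at column tmp[i]. The helper is hypothetical, and the original presumably returns a PyTorch float tensor instead. Np = 2 here because the dataset code produces only frontal (1) and non-frontal (0) pose labels.

```python
import numpy as np


def one_hot(labels, depth):
    """Hypothetical one_hot: integer labels of shape (B,) -> float32 array (B, depth)."""
    labels = np.asarray(labels, dtype=np.int64)
    out = np.zeros((labels.shape[0], depth), dtype=np.float32)
    out[np.arange(labels.shape[0]), labels] = 1.0
    return out


tmp = np.random.randint(2, size=4)   # random target pose per sample (Np = 2)
pose_code = one_hot(tmp, 2)          # conditioning input for G
print(tmp, pose_code, sep='\n')
```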

The discriminator's loss:

    L_id    = loss_criterion(real_output[:, :Nd], batch_id_label)
    L_gan   = loss_criterion_gan(real_output[:, Nd], batch_ones_label) + loss_criterion_gan(syn_output[:, Nd], batch_zeros_label)
    L_pose  = loss_criterion(real_output[:, Nd+1:], batch_pose_label)

    d_loss = L_gan + L_id + L_pose

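real_output is D's output for the real samples, an (Nd+1+Np)-dimensional vector per sample. D has to classify each real sample into one of the identity classes represented by the first Nd entries, which is why the first line is written this way. This corresponds to the identity term of the loss in the paper.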

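D must also separate real from synthesized samples using the (Nd+1)-th output (index Nd). The second line does this by treating all real samples as positives and all synthesized samples as negatives. This is a binary classification problem, so binary cross-entropy is used.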

D also has to classify the pose of real samples correctly, which is what the third line does.

With this explanation, the generator's loss should be easy to follow from the code.
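The following NumPy sketch makes the index layout concrete. It swaps PyTorch for NumPy so it runs standalone, and all function names are hypothetical. `d_loss` mirrors the three D-loss lines above. `g_loss` is an assumption based on the paper's objective, not code from this post. In that objective, G wants D to call its output real, to assign it the input identity, and to predict the target pose that G was conditioned on.

```python
import numpy as np


def cross_entropy(logits, target):
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(target)), target].mean()


def bce_with_logits(logit, label):
    softplus_neg = np.logaddexp(0.0, -logit)
    return (label * softplus_neg + (1 - label) * (logit + softplus_neg)).mean()


def split_heads(output, Nd):
    """Split D's B x (Nd+1+Np) output into identity, real/fake and pose parts."""
    return output[:, :Nd], output[:, Nd], output[:, Nd + 1:]


def d_loss(real_output, syn_output, id_label, pose_label, Nd):
    real_id, real_gan, real_pose = split_heads(real_output, Nd)
    _, syn_gan, _ = split_heads(syn_output, Nd)
    L_gan = (bce_with_logits(real_gan, np.ones(len(real_gan)))
             + bce_with_logits(syn_gan, np.zeros(len(syn_gan))))
    return L_gan + cross_entropy(real_id, id_label) + cross_entropy(real_pose, pose_label)


def g_loss(syn_output, id_label, pose_code_label, Nd):
    # Assumed G objective: look real, keep the input identity, match the target pose.
    syn_id, syn_gan, syn_pose = split_heads(syn_output, Nd)
    L_gan = bce_with_logits(syn_gan, np.ones(len(syn_gan)))
    return L_gan + cross_entropy(syn_id, id_label) + cross_entropy(syn_pose, pose_code_label)


Nd, Np, B = 5, 2, 4
rng = np.random.default_rng(0)
real_out = rng.normal(size=(B, Nd + 1 + Np))
syn_out = rng.normal(size=(B, Nd + 1 + Np))
ids = rng.integers(0, Nd, size=B)
poses = rng.integers(0, Np, size=B)
print(d_loss(real_out, syn_out, ids, poses, Nd), g_loss(syn_out, ids, poses, Nd))
```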

 

Project code download: here

See the README file.
