My code is mainly a modification of this GitHub repo; really I only rewrote the dataset part myself. The code uses PyTorch 1.0.
The linked code needs several .npy files which the author neither provides nor includes generation code for, so I had to roll up my sleeves, write that part myself, and hook it up to the author's code.
For my walkthrough of the DR-GAN paper, see here.
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class FaceIdPoseDataset(Dataset):
    # yields [image, id_label, pose_label]; images come out as C x H x W tensors

    def __init__(self, root, transform=None):
        # root is a txt file listing the CFP image paths, one per line
        with open(root, 'r') as f:
            data = ['./dataset' + line.strip() for line in f.readlines()]
        self.data = np.random.permutation(data)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        img = Image.open(sample)
        img = img.convert('RGB')
        # pose label read off the path: 1 for frontal, 0 for profile
        if sample.find('frontal') != -1:
            pose_label = 1
        else:
            pose_label = 0
        # identity label parsed from a fixed position in the path
        id_label = int(sample[22:25])
        img = self.transform(img)
        return [img.float(), id_label, pose_label]

def get_batch(root, batch_size):
    data_set = FaceIdPoseDataset(root,
                                 transform=transforms.Compose([
                                     transforms.Resize((110, 110)),
                                     transforms.RandomCrop((96, 96)),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
                                 ]))
    dataloader = DataLoader(data_set, batch_size=batch_size,
                            shuffle=True, drop_last=True)  # drop_last matters: a smaller final batch would break the fixed-size label tensors
    return dataloader
The idea: the custom Dataset reads a txt file that records the file names of the CFP dataset, stores the names in a list, and shuffles it.
Also note the transform argument in the FaceIdPoseDataset constructor; it is necessary. In __getitem__ you simply read by index. Images must be opened with PIL.Image.open and then converted to RGB (required), after which the transform handles data augmentation.
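A minimal usage sketch (the list file name below is hypothetical):

dataloader = get_batch('cfp_train_list.txt', batch_size=64)  # hypothetical list file
for images, id_labels, pose_labels in dataloader:
    print(images.shape)  # torch.Size([64, 3, 96, 96]) after the 110 -> random 96 crop
    break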
I only ran experiments with single-image DR-GAN. The code matches what you see in the paper, and all of it is easy to follow; let's look at the discriminator part.
class Discriminator(nn.Module):
"""
multi-task CNN for identity and pose classification
### init
    Nd : number of identities to classify
    Np : number of poses to classify
"""
def __init__(self, Nd, Np, channel_num):
super(Discriminator, self).__init__()
convLayers = [
nn.Conv2d(channel_num, 32, 3, 1, 1, bias=False), # Bxchx96x96 -> Bx32x96x96
nn.BatchNorm2d(32),
nn.ELU(),
nn.Conv2d(32, 64, 3, 1, 1, bias=False), # Bx32x96x96 -> Bx64x96x96
nn.BatchNorm2d(64),
nn.ELU(),
nn.ZeroPad2d((0, 1, 0, 1)), # Bx64x96x96 -> Bx64x97x97
nn.Conv2d(64, 64, 3, 2, 0, bias=False), # Bx64x97x97 -> Bx64x48x48
nn.BatchNorm2d(64),
nn.ELU(),
nn.Conv2d(64, 64, 3, 1, 1, bias=False), # Bx64x48x48 -> Bx64x48x48
nn.BatchNorm2d(64),
nn.ELU(),
nn.Conv2d(64, 128, 3, 1, 1, bias=False), # Bx64x48x48 -> Bx128x48x48
nn.BatchNorm2d(128),
nn.ELU(),
nn.ZeroPad2d((0, 1, 0, 1)), # Bx128x48x48 -> Bx128x49x49
nn.Conv2d(128, 128, 3, 2, 0, bias=False), # Bx128x49x49 -> Bx128x24x24
nn.BatchNorm2d(128),
nn.ELU(),
nn.Conv2d(128, 96, 3, 1, 1, bias=False), # Bx128x24x24 -> Bx96x24x24
nn.BatchNorm2d(96),
nn.ELU(),
nn.Conv2d(96, 192, 3, 1, 1, bias=False), # Bx96x24x24 -> Bx192x24x24
nn.BatchNorm2d(192),
nn.ELU(),
nn.ZeroPad2d((0, 1, 0, 1)), # Bx192x24x24 -> Bx192x25x25
nn.Conv2d(192, 192, 3, 2, 0, bias=False), # Bx192x25x25 -> Bx192x12x12
nn.BatchNorm2d(192),
nn.ELU(),
nn.Conv2d(192, 128, 3, 1, 1, bias=False), # Bx192x12x12 -> Bx128x12x12
nn.BatchNorm2d(128),
nn.ELU(),
nn.Conv2d(128, 256, 3, 1, 1, bias=False), # Bx128x12x12 -> Bx256x12x12
nn.BatchNorm2d(256),
nn.ELU(),
nn.ZeroPad2d((0, 1, 0, 1)), # Bx256x12x12 -> Bx256x13x13
nn.Conv2d(256, 256, 3, 2, 0, bias=False), # Bx256x13x13 -> Bx256x6x6
nn.BatchNorm2d(256),
nn.ELU(),
nn.Conv2d(256, 160, 3, 1, 1, bias=False), # Bx256x6x6 -> Bx160x6x6
nn.BatchNorm2d(160),
nn.ELU(),
nn.Conv2d(160, 320, 3, 1, 1, bias=False), # Bx160x6x6 -> Bx320x6x6
nn.BatchNorm2d(320),
nn.ELU(),
nn.AvgPool2d(6, stride=1), # Bx320x6x6 -> Bx320x1x1
]
self.convLayers = nn.Sequential(*convLayers)
self.fc = nn.Linear(320, Nd+1+Np)
        # initialize all weights from N(0, 0.02)
for m in self.modules():
if isinstance(m, nn.Conv2d):
m.weight.data.normal_(0, 0.02)
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.02)
    def forward(self, input):
        # conv stack + average pooling gives a B x 320 x 1 x 1 feature
        x = self.convLayers(input)
        x = x.view(-1, 320)
        # fully connected layer
        x = self.fc(x)  # B x 320 -> B x (Nd+1+Np)
        return x
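As a quick sanity check on the shapes (a sketch; Nd=500 and Np=2 match CFP's 500 identities and the frontal/profile pose labels used above):

import torch

D = Discriminator(Nd=500, Np=2, channel_num=3)
out = D(torch.randn(4, 3, 96, 96))
print(out.shape)  # torch.Size([4, 503]): Nd identity logits + 1 real/fake logit + Np pose logits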
In the constructor, the layers follow the usual conv-stack-followed-by-FC design, and the parameters are initialized at the end. Because batch norm is used, the convolutions can do without a bias term.
Worth noting here are the ZeroPad2d operations in the conv stack. What are they for? The encoder is set up exactly the same way, so why does the encoder need ZeroPad2d? It is there to match the decoder: the decoder is a chain of transposed convolutions, built as a mirror image of the encoder. Since the encoder has the same structure as D, let's go straight to the decoder code.
class Generator(nn.Module):
"""
Encoder/Decoder conditional GAN conditioned with pose vector and noise vector
### init
    Np : dimension of the pose vector (corresponds to the number of discrete pose classes in the data)
Nz : Dimension of noise vector
"""
def __init__(self, Np, Nz, channel_num):
super(Generator, self).__init__()
self.features = []
        G_enc_convLayers = [
            # ... encoder construction omitted: the same conv stack as the Discriminator above ...
        ]
self.G_enc_convLayers = nn.Sequential(*G_enc_convLayers)
G_dec_convLayers = [
nn.ConvTranspose2d(320,160, 3,1,1, bias=False), # Bx320x6x6 -> Bx160x6x6
nn.BatchNorm2d(160),
nn.ELU(),
nn.ConvTranspose2d(160, 256, 3,1,1, bias=False), # Bx160x6x6 -> Bx256x6x6
nn.BatchNorm2d(256),
nn.ELU(),
nn.ConvTranspose2d(256, 256, 3,2,0, bias=False), # Bx256x6x6 -> Bx256x13x13
nn.BatchNorm2d(256),
nn.ELU(),
Crop([0, 1, 0, 1]),
nn.ConvTranspose2d(256, 128, 3,1,1, bias=False), # Bx256x12x12 -> Bx128x12x12
nn.BatchNorm2d(128),
nn.ELU(),
nn.ConvTranspose2d(128, 192, 3,1,1, bias=False), # Bx128x12x12 -> Bx192x12x12
nn.BatchNorm2d(192),
nn.ELU(),
            nn.ConvTranspose2d(192, 192, 3, 2, 0, bias=False),  # Bx192x12x12 -> Bx192x25x25
nn.BatchNorm2d(192),
nn.ELU(),
Crop([0, 1, 0, 1]),
nn.ConvTranspose2d(192, 96, 3,1,1, bias=False), # Bx192x24x24 -> Bx96x24x24
nn.BatchNorm2d(96),
nn.ELU(),
nn.ConvTranspose2d(96, 128, 3,1,1, bias=False), # Bx96x24x24 -> Bx128x24x24
nn.BatchNorm2d(128),
nn.ELU(),
nn.ConvTranspose2d(128, 128, 3,2,0, bias=False), # Bx128x24x24 -> Bx128x49x49
nn.BatchNorm2d(128),
nn.ELU(),
Crop([0, 1, 0, 1]),
nn.ConvTranspose2d(128, 64, 3,1,1, bias=False), # Bx128x48x48 -> Bx64x48x48
nn.BatchNorm2d(64),
nn.ELU(),
nn.ConvTranspose2d(64, 64, 3,1,1, bias=False), # Bx64x48x48 -> Bx64x48x48
nn.BatchNorm2d(64),
nn.ELU(),
nn.ConvTranspose2d(64, 64, 3,2,0, bias=False), # Bx64x48x48 -> Bx64x97x97
nn.BatchNorm2d(64),
nn.ELU(),
Crop([0, 1, 0, 1]),
nn.ConvTranspose2d(64, 32, 3,1,1, bias=False), # Bx64x96x96 -> Bx32x96x96
nn.BatchNorm2d(32),
nn.ELU(),
nn.ConvTranspose2d(32, channel_num, 3,1,1, bias=False), # Bx32x96x96 -> Bxchx96x96
nn.Tanh(),
]
self.G_dec_convLayers = nn.Sequential(*G_dec_convLayers)
self.G_dec_fc = nn.Linear(320+Np+Nz, 320*6*6)
        # initialize all weights from N(0, 0.02)
for m in self.modules():
if isinstance(m, nn.Conv2d):
m.weight.data.normal_(0, 0.02)
elif isinstance(m, nn.ConvTranspose2d):
m.weight.data.normal_(0, 0.02)
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.02)
def forward(self, input, pose, noise):
x = self.G_enc_convLayers(input) # Bxchx96x96 -> Bx320x1x1
x = x.view(-1,320)
self.features = x
x = torch.cat([x, pose, noise], 1) # Bx320 -> B x (320+Np+Nz)
x = self.G_dec_fc(x) # B x (320+Np+Nz) -> B x (320x6x6)
x = x.view(-1, 320, 6, 6) # B x (320x6x6) -> B x 320 x 6 x 6
x = self.G_dec_convLayers(x) # B x 320 x 6 x 6 -> Bxchx96x96
return x
Notice the Crop module in the G_dec_convLayers list. It trims the feature map, dropping the last row and column. Why do this? It comes down to the output size that nn.ConvTranspose2d's forward produces.
Think about it: a transposed convolution (deconvolution) enlarges the feature map, usually by a factor of two, so you would expect the enlarged map to always have an even size.
But plain convolution does not behave that way. For instance, with stride 2 and 'same'-style padding, an input map of size 15 and one of size 14 both yield a size-7 map.
So should the corresponding transposed convolution output a map of size 14, or of size 15? The answer is 15; but what if the input of the corresponding convolution was actually 14?
The decoder above is exactly in that situation, which is why we crop off a one-pixel border.
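The Crop module itself is not shown in this excerpt; a minimal sketch consistent with how it is used (assuming its [left, right, top, bottom] argument order mirrors nn.ZeroPad2d) would be:

import torch
import torch.nn as nn

class Crop(nn.Module):
    # crop = [left, right, top, bottom], mirroring nn.ZeroPad2d's argument order
    def __init__(self, crop):
        super(Crop, self).__init__()
        self.crop = crop

    def forward(self, x):
        left, right, top, bottom = self.crop
        h, w = x.size(2), x.size(3)
        return x[:, :, top:h - bottom, left:w - right]

# Size arithmetic: ConvTranspose2d(k=3, s=2, p=0) maps 6 -> (6-1)*2 + 3 = 13,
# and Crop([0, 1, 0, 1]) trims the 13x13 map back to the 12x12 the mirrored conv expects.
x = torch.randn(1, 256, 6, 6)
y = nn.ConvTranspose2d(256, 256, 3, 2, 0)(x)
print(y.shape)                       # torch.Size([1, 256, 13, 13])
print(Crop([0, 1, 0, 1])(y).shape)   # torch.Size([1, 256, 12, 12])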
loss_criterion = nn.CrossEntropyLoss()
loss_criterion_gan = nn.BCEWithLogitsLoss()
One is the standard multi-class cross entropy; the other is the binary cross entropy (on logits).
batch_ones_label = t.ones(conf.batch_size)    # labels for the real/fake discrimination
batch_zeros_label = t.zeros(conf.batch_size)
fixed_noise = t.FloatTensor(
    np.random.uniform(-1, 1, (conf.batch_size, conf.nz)))
tmp = t.LongTensor(np.random.randint(conf.np, size=conf.batch_size))
pose_code = one_hot(tmp, conf.np)        # used for conditioning
pose_code_label = t.LongTensor(tmp)      # used for the cross-entropy loss
The first two lines are the real/fake labels for real and synthesized samples, respectively.
fixed_noise is the noise z mentioned in the paper.
tmp exists only to produce the pose code below, a one-hot encoding that treats the discrete pose as a classification target.
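The one_hot helper is not defined in the excerpt; a minimal sketch consistent with how it is called:

import torch

def one_hot(labels, num_classes):
    # labels: LongTensor of shape (B,) -> one-hot FloatTensor of shape (B, num_classes)
    out = torch.zeros(labels.size(0), num_classes)
    out[torch.arange(labels.size(0)), labels] = 1.0
    return out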
Below is D's loss:
L_id = loss_criterion(real_output[:, :Nd], batch_id_label)
L_gan = loss_criterion_gan(real_output[:, Nd], batch_ones_label) + loss_criterion_gan(syn_output[:, Nd], batch_zeros_label)
L_pose = loss_criterion(real_output[:, Nd+1:], batch_pose_label)
d_loss = L_gan + L_id + L_pose
real_output is D's output for a real sample, an (Nd+1+Np)-dimensional vector. D should classify real samples into one of the first Nd identity slots; that is what the first line does, and it matches the loss in the paper.
D must also push fake samples into the (Nd+1)-th slot: that is the second line, treating real samples as positives and synthesized samples as negatives. It is a binary classification problem, which is why the binary cross entropy is used.
D also has to classify the pose of the real samples correctly, which is the third line.
With that explained, G's loss should be clear from the code; a sketch follows below.
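For completeness, a hedged sketch of what G's loss looks like, symmetric to D's (syn_output is assumed to be D's output on the synthesized images, pose_code_label the sampled target poses from above):

# Sketch: G wants its synthesized images judged as real, recognized as the
# input identity, and classified as the target pose.
L_id = loss_criterion(syn_output[:, :Nd], batch_id_label)
L_gan = loss_criterion_gan(syn_output[:, Nd], batch_ones_label)
L_pose = loss_criterion(syn_output[:, Nd+1:], pose_code_label)
g_loss = L_gan + L_id + L_pose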
Project code download: here
See the README file.