本文利用GAN模型去生成不同视角下的步态剪影图,并且解决背包等协变量的问题。
创新之处在与,相比于传统的GAN模型,其使用了两个不同的判别器:一个是传统的真伪辨别器,另一个用于维持步态剪影中含有的身份信息。
最大的挑战:生成的步态剪影图不仅要看上去是真的,还需要保持人体可辨别的身份信息。
插播关于GEI的制作步骤:
① 将二值轮廓图像缩放为统一的标准尺寸,即归一化
② 将一个步态序列的归一化图片加权求和生成一张图片(0-255范围)
GAN中有两个角色:**生成器G和判别器D,他们相互对抗共同训练。**其中,G类似生成假钞的罪犯,D类似分辨假钞的警察,G和D相互对抗,G生成的假币更加逼真,D的判别能力更加强大。训练的最终目标是使得D无法再判别出G生成数据的真伪(即达到纳什均衡)。
生成器G:接受一个噪声向量,生成数据
判别器D:接受真实数据,或者生成器G生成的数据,对他们进行二分类(分辨出真假)
训练过程:
其中需要注意的是,判别器D的生成是0-1之间的数据。例如D判别一个数据是100%真实数据,它将输出1。
· 训练判别器:在最初的iteration中,对D进行反复训练,使得D(真实数据) >> 1, D(G(噪声数据)) >> 0,从而logD(真实数据) >> 0,log(1-D(G(噪声数据)))>> 0,让loss最大(loss的最大值逼近于0)。
· 训练生成器:目的是让判别器出错,让 D(G(噪声数据)) >> 1,从而log(1-D(G(噪声数据)))>> 负无穷小,让loss最小(无穷小)。
GAN的输入也可以是图片而不是噪声向量。例如在PixelDTGAN中,它可以识别输入图片与目标图片之间的像素级转换,同时可以建立起输入域与目标域之间的语义含义,从而保证生成的图片看上去真实,同时维持其语义含义。
对于Real/fake D,其作用就是用于尽最大可能地让生成的数据看上去是真实的,它的输入是生成的数据以及真实的数据,产生一个概率值用于判断图像是由生成器生成的还是真实数据。例如:对于真实图片,label为1;对于生成器生成的图片,label为0,做二分类。
其损失函数为(最小化):
对于Domain D,起作用就是用于保持语义信息,它的输入是一对源图像的目标图像(一个相关的和一个不相关的),产生一个概率值用于判断这对图像是否关联。例如:对于一对源图片与目标图片,若相关联label为1,若不关联label为0,做二分类。
其损失函数为(最小化):
如果目标图像是与源图像相关的图像,让D(I_s, I) >> 1;如果目标图像是生成的图像,或者是与源图像无关的图像,让D(I_s, I) >> 0
与PixelDTGAN的想法类似,GaitGAN将所有视角以及带有协变量的GEI当做源图像,把正常的90°的GEI当做目标图像。
对于编码解码器,利用常规的CNN:
对于Real/fake D,其用于预测生成的图片是否为真实的图片。如果生成90°的NM图片,输出为1,否则为0。
对于Domain D(文中命名为identification discriminator),用于预测图片是否相关,如果目标图与源图片是同一个人,输出1;如果不是同一个人或者是生成出来的图片,则输出0。
在解码器中运用到了反卷积,此处做补充:
class NetG(nn.Module):
def __init__(self, nc=3, ngf=96):
super(NetG, self).__init__()
self.converter = nn.Sequential(
nn.Conv2d(nc, ngf, kernel_size=4, stride=2, padding=1, bias=False),
nn.LeakyReLU(0.2, True),
nn.Conv2d(ngf, ngf*2, kernel_size=4, stride=2, padding=1,
bias=False),
nn.BatchNorm2d(ngf*2),
nn.LeakyReLU(0.2, True),
nn.Conv2d(ngf*2, ngf*4, kernel_size=4, stride=2, padding=1,
bias=False),
nn.BatchNorm2d(ngf*4),
nn.LeakyReLU(0.2, True),
nn.Conv2d(ngf*4, ngf*8, kernel_size=4, stride=2, padding=1,
bias=False),
nn.BatchNorm2d(ngf*8),
nn.LeakyReLU(0.2, True),
# 反卷积
nn.ConvTranspose2d(ngf*8, ngf*4,
kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ngf*4),
nn.ReLU(True),
# 反卷积
nn.ConvTranspose2d(ngf*4, ngf*2, kernel_size=4,
stride=2, padding=1, bias=False),
nn.BatchNorm2d(ngf*2),
nn.ReLU(True),
# 反卷积
nn.ConvTranspose2d(ngf*2, ngf, kernel_size=4,
stride=2, padding=1, bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
# 反卷积
nn.ConvTranspose2d(ngf, nc, kernel_size=4, stride=2, padding=1,
bias=False),
nn.Tanh()
)
def forward(self, x):
x = self.converter(x)
return x
class NetD(nn.Module):
def __init__(self, nc=3, ndf=96):
super(NetD, self).__init__()
self.discriminator = nn.Sequential(
nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1, bias=False),
nn.LeakyReLU(0.2, True),
nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1,
bias=False),
nn.BatchNorm2d(ndf*2),
nn.LeakyReLU(0.2, True),
nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1,
bias=False),
nn.BatchNorm2d(ndf*4),
nn.LeakyReLU(0.2, True),
nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1,
bias=False),
nn.BatchNorm2d(ndf*8),
nn.LeakyReLU(0.2, True),
nn.Conv2d(ndf*8, 1, kernel_size=4, stride=4, bias=False),
nn.Sigmoid()
)
def forward(self, x):
x = self.discriminator(x)
return x.view(-1, 1)
'''
domain discriminator
'''
class NetA(nn.Module):
def __init__(self, nc=3, ndf=96):
super(NetA, self).__init__()
self.discriminator = nn.Sequential(
nn.Conv2d(nc*2, ndf, kernel_size=4, stride=2, padding=1,
bias=False),
nn.LeakyReLU(0.2, True),
nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1,
bias=False),
nn.BatchNorm2d(ndf*2),
nn.LeakyReLU(0.2, True),
nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1,
bias=False),
nn.BatchNorm2d(ndf*4),
nn.LeakyReLU(0.2, True),
nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1,
bias=False),
nn.BatchNorm2d(ndf*8),
nn.LeakyReLU(0.2, True),
nn.Conv2d(ndf*8, 1, kernel_size=4, stride=4, bias=False),
nn.Sigmoid()
)
def forward(self, x):
x = self.discriminator(x)
return x.view(-1, 1)
'''更新生成器'''
lossG = 0
optimG.zero_grad()
fake = netg(img)
output = netd(fake) # 生成的图
label.fill_(real_label) # 骗过D,是真标签
lossGD = F.binary_cross_entropy(output, label)
lossG += lossGD.item()
lossGD.backward(retain_graph=True)
faked = th.cat((img, fake), 1)
output = neta(faked) # 生成的图
label.fill_(real_label) # 骗过A,是真标签
lossGA = F.binary_cross_entropy(output, label)
lossG += lossGA.item()
lossGA.backward()
optimG.step() # 一起更新D和A,骗过他们
'''更新fake/real 辨别器'''
label.fill_(real_label) # 真标签
output = netd(ass_label) # 相关联的图,是真图
lossD_real1 = F.binary_cross_entropy(output, label)
lossD += lossD_real1.item()
lossD_real1.backward()
label.fill_(real_label) # 真标签
output1 = netd(noass_label) # 不相关的图,但也是真的图
lossD_real2 = F.binary_cross_entropy(output1, label)
lossD == lossD_real2.item()
lossD_real2.backward()
label.fill_(fake_label) # 假标签
fake = netg(img).detach() # 生成的图
output2 = netd(fake)
lossD_fake = F.binary_cross_entropy(output2, label)
lossD += lossD_fake.item()
lossD_fake.backward()
optimD.step() # 更新D
'''更新Domain 辨别器'''
label.fill_(real_label) # 真标签
output1 = neta(assd) # 相关的图
lossA_real1 = F.binary_cross_entropy(output1, label)
lossA += lossA_real1.item()
lossA_real1.backward()
label.fill_(fake_label) # 假标签
output = neta(noassd) # 不相关的图
lossA_real2 = F.binary_cross_entropy(output, label)
lossA += lossA_real2.item()
lossA_real2.backward()
label.fill_(fake_label) # 假标签
output = neta(faked) # 生成的图
lossA_fake = F.binary_cross_entropy(output, label)
lossA += lossA_fake.item()
lossA_fake.backward()
optimA.step() # 更新A