InfoGAN-无监督式GAN

1. InfoGAN简介:

       普通的GAN存在无约束、不可控、噪声信号z很难解释等问题,2016年发表在NIPS顶会上的文章InfoGAN:Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets,提出了InfoGAN的生成对抗网络。InfoGAN 主要特点是对GAN进行了一些改动,成功地让网络学到了可解释的特征,网络训练完成之后,我们可以通过设定输入生成器的隐含编码来控制生成数据的特征。

        作者将输入生成器的随机噪声分成了两部分:一部分是随机噪声Z, 另一部分是由若干隐变量拼接而成的latent code c。其中,c会有先验的概率分布,可以离散也可以连续,用来代表生成数据的不同特征。例如:对于MNIST数据集,c包含离散部分和连续部分,离散部分取值为0~9的离散随机变量(表示数字),连续部分有两个连续型随机变量(分别表示倾斜度和粗细度)。

        为了让隐变量c能够与生成数据的特征产出关联,作者引入了互信息来对c进行约束,因为c对生成数据G(z, c)具有可解释性,那么c和G(z, c)应该具有较高的相关性,即它们之间的互信息比较大。互信息是两个随机变量之间依赖程度的度量,互信息越大就说明生成网络在根据c的信息生成数据时,隐编码c的信息损失越低,即生成数据保留的c的信息越多。因此,我们希望c和G(z, c)之间的互信息I(c; G(z, c))越大越好,故模型的目标函数变为:

        但是由于在c与G(z, c)的互信息的计算中,真实的P(c|x)难以获得,因此在具体的优化过程中,作者采用了变分推断的思想,引入了变分分布Q(c|x)来逼近P(c|x),它是基于最优互信息下界的轮流迭代实现最终的求解,于是InfoGAN的目标函数变为:


2. InfoGAN的基本结构为:

InfoGAN的基本结构

        其中,真实数据Real_data只是用来跟生成的Fake_data混合在一起进行真假判断,并根据判断的结果更新生成器和判别器,从而使生成的数据与真实数据接近。生成数据既要参与真假判断,还需要和隐变量C_vector求互信息,并根据互信息更新生成器和判别器,从而使得生成图像中保留了更多隐变量C_vector的信息。

        因此可以对InfoGAN的基本结构进行如下的拆分,其中判别器D和Q共用所有卷积层,只是最后的全连接层不同。从另一个角度来看,G-Q联合网络相当于是一个自编网络,G相当于一个编码器,而Q相当于一个解码器,生成数据Fake_data相当于对输入隐变量C_vector的编码。

InfoGAN的拆分结构

生成器G的输入为:(batch_size, noise_dim + discrete_dim + continuous_dim),其中noise_dim为输入噪声的维度,discrete_dim为离散隐变量的维度,continuous_dim为连续隐变量的维度。生成器G的输出为(batch_size, channel, img_cols, img_rows)。

判别器D的输入为:(batch_size, channel, img_cols, img_rows),判别器D的输出为:(batch_size, 1)。

判别器Q的输入为:(batch_size, channel, img_cols, img_rows),Q的输出为:(batch_size, discrete_dim + continuous_dim)


3. InfoGAN的优化目标函数为:

        InfoGAN的目标函数变为:     

        对于判别器D而言,优化目标函数为:    

D_real, _, _ = Discriminator(real_imgs)        #  real_imgs为用于训练的真实图像

gen_imgs = Generator(noise, c_discrete, c_continuous)        #  c_discrete为输入的离散型隐变量, c_continuous为输入的连续型隐变量  

D_fake, _, _ = Discriminator(gen_imgs) 

D_real_loss = torch.nn.BCELoss(D_real, y_real)          #  y_real 真实图像的标签,都为1

D_fake_loss = torch.nn.BCELoss(D_fake, y_fake)         # y_fake为生成图像的标签,都为0

D_loss = D_real_loss + D_fake_loss

        对于生成器G而言,优化目标函数为: 

gen_imgs = Generator(noise, c_discrete, c_continuous)                   #  c_discrete为输入的离散隐变量,c_continuous为输入的连续隐变量

D_fake, D_continuous, D_discrete = Discriminator(gen_imgs) 

G_loss = torch.nn.BCELoss( D_fake, y_real)                #  y_real 真实图像的标签,都为1

         对于G-Q联合网络而言,它的优化目标函数为:  , 其中              因此,

discrete_loss = torch.nn.CELoss(D_discrete, c_discrete)

continuous_loss = torch.nn.MSELoss(D_continuous, c_continuous)

info_loss = discrete_loss + continuous_loss

info_loss.backward()

info_optimizer.step()       # 其中,info_optimizer = optim.Adam(itertools.chain(Generator.parameters(), Discriminator.parameters()), lr = learning_rate, betas=(beta1, beta2))

       简而言之,InfoGAN中单独判别器D的优化目标函数只有对抗损失,单独生成器G的优化目标函数也只有对抗损失,生成器G和辅助判别器Q联合网络的优化目标函数是info损失,包含离散损失和连续损失两个部分。其中,判别器D和辅助判别器Q共用卷积层,只是最后的全连接层不同。


参考链接:http://aistudio.baidu.com/aistudio/projectdetail/29156 中山大学黄涛对论文InfoGAN:Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets的复现。

你可能感兴趣的:(InfoGAN-无监督式GAN)