深度学习【41】InfoGAN

InfoGAN利用互信息和变分自编码(VAE,参考我这篇博客)将样本的label信息加入了GAN中。

首先我们列出一些基本公式:
熵: H(X)=XP(X)logP(X) H ( X ) = − ∫ X P ( X ) l o g P ( X )
H(Y|X)=XP(X)YP(Y|X)logP(Y|X) H ( Y | X ) = − ∫ X P ( X ) ∫ Y P ( Y | X ) l o g P ( Y | X )
仔细回想一下期望的计算公式,我们发现H(X) = E(log(P(X))),这个后面会用到。

互信息: I(X;Y)=H(X)H(Y|X) I ( X ; Y ) = H ( X ) − H ( Y | X )
KL散度: DKL(X||Y)=XlogXY D K L ( X | | Y ) = ∫ X l o g X Y

好了,我们步入正题。为了让G网络生成的图片更有语义意思,论文中在z向量的基础上加入了隐藏编码c变量。隐藏编码c表示了所要生成图片的语义,比如mnist中的0-9数字。以mnist为例,隐藏编码c,其实就是一个one-hot向量,比如要生成内容为7的图片,则c向量中的第7索引则为1,其余的位置为0。因此来自G网络的图片可写成G(z,c)。
为了确保隐藏编码c能够起到引导G网络生成的图片有语义信息。论文使用了互信息 I(c;G(z,c)) I ( c ; G ( z , c ) ) ,以降低给定G(z,c)后c的不确定性。这样一来,InfoGAN的优化目标函数为:
这里写图片描述
就是在原始的GAN损失函数上加入一个互信息函数。下文我们只关心互信息函数。

在实际中,I(c;G(z,c))是很难优化的,这是因为我们根本不是知道后验概率 P(c|x) P ( c | x ) 。幸运的是,根据变分理论,我们可以用一个比 P(c|x) P ( c | x ) 低界的 Q(c|x) Q ( c | x ) 函数来近似 P(c|x) P ( c | x )

有了上面的Q函数作为指导,我们开始推导一下优化目标函数。

我们直接贴出论文的结果:
深度学习【41】InfoGAN_第1张图片
论文还给出了一个定理:
这里写图片描述

根据期望计算公式、KL散度公式以及论文给出的定理,我们来推导一下H(c;G(z,c)):

H(c;G(z,c))=ExG(z,c)[EcP(c)[logP(c|x)]]=ExG(z,c)[EcP(c|x)[logP(c|x)]]Q=ExG(z,c)[EcP(c|x)[logP(c|x)Q(c|x)+logQ(c|x)]]=ExG(z,c)[EcP(c|x)[logP(c|x)Q(c|x)]+EcP(c|x)[logQ(c|x)]]KL=ExG(z,c)[DKL(P(|x);Q(|x))+EcP(c|x)[logQ(c|x)]] H ( c ; G ( z , c ) ) = E x ∼ G ( z , c ) [ E c ∼ P ( c ) [ l o g P ( c | x ) ] ] = E x ∼ G ( z , c ) [ E c ′ ∼ P ( c | x ) [ l o g P ( c ′ | x ) ] ] 将 Q 函 数 加 入 : = E x ∼ G ( z , c ) [ E c ′ ∼ P ( c | x ) [ l o g P ( c ′ | x ) Q ( c ′ | x ) + l o g Q ( c ′ | x ) ] ] = E x ∼ G ( z , c ) [ E c ′ ∼ P ( c | x ) [ l o g P ( c ′ | x ) Q ( c ′ | x ) ] + E c ′ ∼ P ( c | x ) [ l o g Q ( c ′ | x ) ] ] 结 合 期 望 公 式 和 K L 计 算 公 式 , 有 : = E x ∼ G ( z , c ) [ D K L ( P ( ⋅ | x ) ; Q ( ⋅ | x ) ) + E c ′ ∼ P ( c | x ) [ l o g Q ( c ′ | x ) ] ]

由因为论文将H(c)当做常数处理,因此最后变成最大化 log(Q(|x)) l o g ( Q ( ⋅ | x ) ) 的期望。

在实际使用中,论文对隐藏编码c分为了离散和连续两种情况。以mnist数集来说,离散的c表示类别,连续的c则是用来控制生成字体的风格(论文用一个大小为2的向量)。

那么如何将Q函数的优化引入GAN中呢?论文的做法是将D网络从中间分出一个Q网络分支。而Q网络分别预测出离散的 c c ′ ,以及连续的 c c ′ 的均值和方差。在最大化 log(Q(c|x)) l o g ( Q ( c ′ | x ) ) 时,对于离散的隐藏编码则利用真实的离散c直接用交叉熵损失函数进行优化。而对于连续的隐藏编码,则根据预测出来的均值、方差以及一开始由随机函数生成的连续c,输入对数高斯函数,最大化对数高斯函数。

连续情况的损失函数代码:

class log_gaussian:

  def __call__(self, x, mu, var):

    logli = -0.5*(var.mul(2*np.pi)+1e-6).log() - \
            (x-mu).pow(2).div(var.mul(2.0)+1e-6)

    return logli.sum(1).mean().mul(-1)

你可能感兴趣的:(深度学习)