我们用tensorflow应用于实际项目中时,常常会遇到一种情况:我们有很多的数据,但是只有很少的标注。因为标注需要很多时间。这时我们可能会想到用半监督(semi-supervise)的方法训练数据。但是半监督需要将无标签(unlabeled)的数据用于训练中,这是一个很困难的事情。恰好,最近有一种很火的方法——生成对抗网络(Generative Adversarial Nets,GAN)——中有关于半监督的方法应用,并且在论文中得到了很好的效果,我们参考它们尝试设计自己的半监督网络。
首先聊一聊GAN的发展历史。2016年,ImprovedTechniquesforTrainingGANs将GAN用于生成样本和半监督中(它并不是首例GAN,但它的代码引用是最多的),设计了两类损失包括监督损失和无监督损失,达到了比较好的训练精度。后来,随着时间的推移,人们发现KL散度用于度量GAN这种低维映射到高维的网络损失时,有个理论上的巨大漏洞,容易导致梯度消失,所以出现了WGAN;随后,又有人发现WGAN的权值剪切法会导致梯度极端分布,所以出现了WGAN-GP和WGAN-CT等方法,都是使用梯度惩罚项来实现Lipschitz连续。
下面我们看看它们的具体实现。代码参考。
SSGAN是所有介绍的方法中,最“老”的一种。但是,后续的半监督方法无一不参考了它的思想——将损失分为监督损失和无监督损失。
G_img = generator('gen', z, reuse=False)
d_logits_r, layer_out_r = discriminator('dis', x, reuse=False)
d_logits_f, layer_out_f = discriminator('dis', G_img, reuse=True)
# caculate the unsupervised loss
d_loss_r=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits_r[:, -1]),logits=d_logits_r[:, -1]))
d_loss_f=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits_f[:, -1])*0.9, logits=d_logits_f[:, -1]))
d_loss_f1 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits_f[:, -1]),logits=d_logits_f[:, -1]))
# feature match
f_match = []
for i in range(4):
f_match += [tf.reduce_mean(tf.multiply(layer_out_f[i]-layer_out_r[i], layer_out_f[i]-layer_out_r[i]))]
# caculate the supervised loss
s_label = tf.concat([label, tf.zeros(shape=(batch_size,1))], axis=1)
s_l_r = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=s_label, logits=d_logits_r))
d_loss = d_loss_r + d_loss_f + s_l_r*flag+d_regular
g_loss = d_loss_f1+0.1*tf.reduce_mean(f_match,0)
其中,s_l_r是有标签的损失,d_loss是判别器损失,g_loss是生成器损失。当有标签时,我们将d_loss中的flag设为1,没有时,设为0。
为什么虚拟标签需要乘以0.9,请参照论文内的One-sided label smoothing方法。feature matching 也一样。
wgan-ct相对于WGAN-GP(WGAN的改进型,略去不讲,请参考论文),使用最后两层网络的梯度惩罚。论文中说这样可以减少真实数据的不连续性(未完全理解)。之所以选择这个网络,是因为这个方法也做了半监督实验。代码如下(论文源码):
G_img = generator('gen', z, reuse=False)
d_logits_r1, d_logits_r11,d_logits_r12 = discriminator_with_dropout('dis', x, reuse=False)
d_logits_r2, d_logits_r21,_ = discriminator_with_dropout('dis', x, reuse=True)
d_logits_f, _ ,d_logits_f2 = discriminator_with_dropout('dis', G_img, reuse=True)
# caculate the unsupervised loss
logits_r, logits_f = tf.nn.softmax(d_logits_r1), tf.nn.softmax(d_logits_f)
d_loss_r = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits_r[:, -1]), logits=d_logits_r1[:, -1]))
d_loss_f = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits_f[:, -1]), logits=d_logits_f[:, -1]))
logits=tf.reduce_max(d_logits_r1, -1))
loss_ct=tf.square(d_logits_r1-d_logits_r2)
loss_ct_=0.1*tf.reduce_mean(tf.square(d_logits_r11-d_logits_r21))
CT=loss_ct+loss_ct_
# caculate the supervised loss
s_label = tf.concat([label, tf.zeros(shape=(batch_size,1))], axis=1)
s_l_r = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=s_label, logits=d_logits_r1))
d_l_1, d_l_2 = d_loss_r + d_loss_f, s_l_r
d_loss = d_loss_r + d_loss_f + s_l_r*flag +0.1*tf.reduce_mean(CT)
g_loss = tf.square(tf.reduce_mean(d_logits_f2,0)-tf.reduce_mean(d_logits_r12,0))
all_vars = tf.global_variables()
for v in all_vars:
print(v)
all_vars = tf.global_variables()
g_vars = [v for v in all_vars if 'gen' in v.name]
d_vars = [v for v in all_vars if 'dis' in v.name]
opt_d = tf.train.AdamOptimizer(lr).minimize(d_loss, var_list=d_vars)
opt_g = tf.train.AdamOptimizer(lr).minimize(g_loss, var_list=g_vars)
和SSGAN一样,使用flag作为控制带标签和不带标签的开关。其中CT项是该方法的创新,具体方法为:通过加入dropout,由同一输入得到不同的输出,然后将网络最后两层的输出做差分,作为梯度惩罚项。
我们拿minist(深度学习界的果蝇)做一下实验,使用同样的判别器和生成器,对比SSGAN和WGAN-CT的结果。
首先,作为比较结果,我们得到仅使用带标签数据的训练成果。
其中,每1000次迭代(iteration)更新一次带标签数据,带标签数据一共100个。一次迭代的batch size为50,即一次训练的样本数为50000。
然后,使用同样的判别网络,与数据输入方式,SSGAN的训练精度为
可以看到,最高精度反而降低了。
以下是SSGAN的生成样本。
和WGAN中讨论的一样,生成的样本面临多样性不足的问题(训练多次后有所缓解)。
最后,是WGAN-CT的半监督测试精度结果:
相对于带监督结果,精度只提高了一点。
以下是生成的样本
多样性不足的问题有所缓解。这个实验我没有做满50000次,因为CT-GAN的由于需要计算梯度,所以反向训练比较耗时,感兴趣的读者可以参考我提供的代码地址,自己使用minist复现结果。
两种半监督方法的效果不是很好,比不上无监督方法。其实这也很好理解,因为它们只是沿用了传统GAN的思路定义了损失(实际上只是多出来一个虚拟类),但是并没有思考生成出来的图片如何进一步为判别器所用,提升判别器的精度。
下一步需要做的:
1.Cifar数据集测试.
2.思考两种方法的损失改进方法,使生成的图片能够用于判别器的分类中(不仅仅是真实和虚假两类)。
3.增加VAT方法.
最后,祝您身体健康,再见!