生成对抗网络代码详解(二):cGAN

cGAN的全称为conditional GAN,是条件生成对抗网络,虽然利用GAN就可以学习真实的图片,从而生成一些逼真的图片,但是,生成数据的类别是无法控制的,因此需要某些条件来限制生成的数据,比如标签,因此,用于训练cGAN的数据总是成对的。生成对抗网络代码详解(二):cGAN_第1张图片将数据和标签同时进行训练,当虚假的数据配上任意的标签,质量很差的图片不管标签是什么,或真实的数据配上错误的标签,判别器都需要判定为假,只有当正确的数据配上正确的标签,判别器才判定为真。例如当输入生成数据为蝴蝶,对应的输入标签也为蝴蝶,判别器判别为真,那么此时的生成器是有效的,它成功的骗过了判别器。

上篇博客生成对抗网络代码详解(一):GAN已经介绍过如何创建一个基本的GAN,接下来,我只介绍一些cGAN与GAN的区别。

		layers.append(nn.Linear(in_features=28*28+10, out_features=512, bias=True))

在判别器的第一层输入特征不再是28*28,还需要要加10,增加的部分对应的是标签。

		layers.append(nn.Linear(in_features=z_dim+10, out_features=128))

对于生成器的第一层输入z_dim维的噪声维度同样也需要增加10维。

	def forward(self, x, c):
		x = x.view(x.size(0), -1)
		validity = self.model(torch.cat([x, c], -1))
		return validity
	def forward(self, z, c):
		x = self.model(torch.cat([z, c], dim=1))
		x = x.view(-1, 1, 28, 28)
		return x

在前向传播的过程中,还需要将条件c输入判别器和生成器,然后将输入的数据或噪声与条件c拼接在一起。

由于MNIST手写数字的标签为0-9,因此,在训练过程中需要将他们转换为one-hot类型:

def one_hot(labels, class_num):
	'''把标签转换成one-hot类型'''
	tmp = torch.FloatTensor(labels.size(0), class_num).zero_()
	one_hot = tmp.scatter_(dim=1, index=torch.LongTensor(labels.view(-1, 1)), value=1)
	return one_hot
# 生成 batch_size 个 ont-hot 标签
c = torch.FloatTensor(batch_size, 10).zero_()
c = c.scatter_(dim=1, index=torch.LongTensor(np.random.choice(10, batch_size).reshape([batch_size, 1])), value=1)
c = c.to(device)

你可能感兴趣的:(深度学习,GAN)