什么是GAN网络?
GAN(Generative Adversarial Networks)的初衷就是生成不存在于真实世界的数据,类似于使得 AI具有创造力或者想象力。应用场景如下:
GAN网络有很多变形,下面主要介绍其中一种常用变形,ACGAN
那什么是ACGAN呢?
在计算机视觉里面,我们用的比较多的就是分类了,而训练分类的前提是收集足够多的各种分类的数据用来训练,这也是我们比较头疼的一个步骤,数据来源少无法训练怎么办?
ACGAN的一个用途就是用来生成多分类增强数据,只要你有每种分类数据大概2000张以上就能进行训练并生成指定分类的数据,下面是它的一个原理图:
如上图所示,ACGAN相对于GAN不同点在于:
所以表面上看还是挺简单的,不过深究到细节就有几点注意了:
1、针对输入多出的类别要怎样和噪声进行融合呢?我们首先想到的应该就是把类别和噪声进行连接成为新数组对吧,这种自己试过之后效果并不是特别好,因为该情况下类别无法深入影响到每个噪声变量。看一下下面这种方式:
def build_generator(self):
# generator负责生成图片,所以卷积过程是从小到大
model = Sequential()
row_shape = int(self.img_cols / 8)
col_shape = int(self.img_rows / 8)
# model.add(Dense(self.latent_dim * 8 * row_shape * col_shape, activation="relu", input_dim=self.latent_dim))
# model.add(Reshape((row_shape, col_shape, self.latent_dim * 8)))
# model.add(LeakyReLU(alpha=0.2))
# model.add(Conv2DTranspose(self.latent_dim * 4, 3, strides=2, padding='same'))
# model.add(LeakyReLU(alpha=0.2))
# model.add(Conv2DTranspose(self.latent_dim * 2, 3, strides=2, padding='same'))
# model.add(LeakyReLU(alpha=0.2))
model.add(Dense(256 * row_shape * col_shape, activation="relu", input_dim=self.latent_dim))
model.add(Reshape((row_shape, col_shape, 256)))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2DTranspose(128, kernel_size=3, strides=2, padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2DTranspose(64, kernel_size=3, strides=2, padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2DTranspose(self.channels, kernel_size=3, strides=2, padding='same'))
model.add(Activation("tanh"))
model.summary()
noise = Input(shape=(self.latent_dim,))
label = Input(shape=(1,), dtype='int32')
label_embedding = Flatten()(Embedding(self.num_classes, 100)(label))
model_input = multiply([noise, label_embedding])
img = model(model_input)
return Model([noise, label], img)
上面这种方式是通过keras提供的Embedding层进行融合,Embedding层可以看做和one-hot类似的形式,但是它比one-hot优势体现在Embedding生成的变量并不是指定位置为1,其他为0的形式,而是每个位置的值都是一个浮点数,简单来说Embedding其实相当于一个神经网络层,把输入映射到多维空间,这样做的好处是空间特征更加丰富
回到上面的例子,利用Embedding输出和噪声进行相乘能更好将类别信息融合到噪声里面,经过测试相同的代码和数据,采用Embedding结构的网络生成的类别会更准确一些。
2、针对输出增加了类别的判断,网络结构上面也改变了一些,首先是D网络的输出:
def build_discriminator(self):
model = Sequential()
model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
model.add(ZeroPadding2D(padding=((0, 1), (0, 1))))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Flatten())
model.summary()
img = Input(shape=self.img_shape)
# Extract feature representation
features = model(img)
# Determine validity and label of the image
validity = Dense(1, activation="sigmoid")(features)
label = Dense(self.num_classes, activation="softmax")(features)
return Model(img, [validity, label])
可以看到原先GAN网络输出只有validity,现在多了一个label,不过也不是特别复杂,只是将原先的最后一层网络分别映射成真假和类别两个输出。
3、增加了输出,损失函数的更改更为重要:
optimizer = Adam(0.001, 0.5)
losses = ['binary_crossentropy', 'sparse_categorical_crossentropy']
# Build and compile the discriminator
self.discriminator = self.build_discriminator()
self.discriminator.compile(loss=losses, optimizer=optimizer, metrics=['accuracy'])
原先GAN网络只要判断真假,所以选择binary这种多标签损失函数就行了,ACGAN增加了类别,所以还得加入categorial这种多类别损失函数,两个损失函数分别对应之前的两个输出,两个加起来的结果就是总的损失函数
介绍完ACGAN基本结构,我们来看一下它实际的效果吧,如下所示,我们训练下面四种枪支:
经过一番数据收集,准备好了上面标注的数据量,使用前面提到的ACGAN网络进行训练,下面是训练结果:
1)epoch0:
2) epoch2:
3) epoch23:
可以看到逐步有效果出来了,注意训练数据的图片宽高比尽量保持一致,这样输出的比较符合真实比例,还有一点就是没必要增加太多的数据增强,增加一点训练数据效果还来的更好一些,再来看最后一张,Epoch45:
不知道大家注意到了没有,不管训练多少回,第二种类别即Kar98K这一列偶尔还是会出现不对的类别,第三种即AWM更糟糕,基本很少是对的类别,但是第一种和最后一种却非常稳定,几乎没有出现错的类别,这是什么原因呢?
带着疑问大家可以再回顾一下训练前准备的训练数据情况,这时候应该就能发现有个规律,训练数据量越多的生成的效果也就越准确,而且有个分界点,就是数据量低于2000的明显效果不太好,高于2000的会正常一些。
正如刚开始介绍ACGAN作用的时候也说过,ACGAN的一个作用就是弥补数据的不足,所以正常情况下也不会有太大的数据作为ACGAN的训练数据,不然就失去意义了,但是一个底线是起码每个分类的训练数据不低于2000张,否则
总结
总的来说,ACGAN还是相当有意思的一种网络,它让计算机拥有了能够模仿现有数据从而生成独特数据的能力,用途也相当广泛,如果用来做数据增强的话,建议每种类别训练数据量尽量和上面的AKM保持一致,这样训练效果才不会太差,而且2000多张数据正常还是可以收集到的,当然个人比较看好的是它的“想象力”,希望后面能开发出变形以及用途。