生成对抗网络是指一类采用对抗训练方式进行学习的深度生成模型,包含的判别网络和生成网络都可以根据不同的生成任务使用不同的网络结构。
生成器: 通过机器生成数据,最终目的是骗过判别器。
判别器: 判断这张图像是真实的还是机器生成的,目的是找出生成器做的假数据。
构建GAN模型的基本逻辑: 现实问题需求→建立实现功能的GAN框架(编程)→训练GAN(生成网络、对抗网络)→成熟的GAN模型→应用。
GAN训练过程:
生成器生成假数据,然后将生成的假数据和真数据都输入判别器,判别器要判断出哪些是真的哪些是假的。判别器第一次判别出来的肯定有很大的误差,然后我们根据误差来优化判别器。现在判别器水平提高了,生成器生成的数据很难再骗过判别器了,所以我们得反过来优化生成器,之后生成器水平提高了,然后反过来继续训练判别器,判别器水平又提高了,再反过来训练生成器,就这样循环往复,直到达到纳什均衡。
DCGAN模型:
generator = Sequential()
generator.add(Dense(7 * 7 * 128, input_shape=[100]))
generator.add(Reshape([7, 7, 128]))
generator.add(BatchNormalization())
generator.add(Conv2DTranspose(64, kernel_size=5, strides=2, padding="same",
activation="relu"))
generator.add(BatchNormalization())
generator.add(Conv2DTranspose(1, kernel_size=5, strides=2, padding="same",
activation="tanh"))
discriminator = Sequential()
discriminator.add(Conv2D(64, kernel_size=5, strides=2, padding="same",
activation=LeakyReLU(0.3),
input_shape=[28, 28, 1]))
discriminator.add(Dropout(0.5))
discriminator.add(Conv2D(128, kernel_size=5, strides=2, padding="same",
activation=LeakyReLU(0.3)))
discriminator.add(Dropout(0.5))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation="sigmoid"))
模型训练:
GAN =Sequential([generator,discriminator])
discriminator.compile(optimizer='adam',loss='binary_crossentropy')
discriminator.trainable = False
GAN.compile(optimizer='adam',loss='binary_crossentropy')
epochs = 150
batch_size = 100
noise_shape=100
with tf.device('/gpu:0'):
for epoch in range(epochs):
print(f"Currently on Epoch {epoch+1}")
for i in range(X_train.shape[0]//batch_size):
if (i+1)%50 == 0:
print(f"\tCurrently on batch number {i+1} of {X_train.shape[0]//batch_size}")
noise=np.random.normal(size=[batch_size,noise_shape])
gen_image = generator.predict_on_batch(noise)
train_dataset = X_train[i*batch_size:(i+1)*batch_size]
train_label=np.ones(shape=(batch_size,1))
discriminator.trainable = True
d_loss_real=discriminator.train_on_batch(train_dataset,train_label)
train_label=np.zeros(shape=(batch_size,1))
d_loss_fake=discriminator.train_on_batch(gen_image,train_label)
noise=np.random.normal(size=[batch_size,noise_shape])
train_label=np.ones(shape=(batch_size,1))
discriminator.trainable = False #while training the generator as combined model,discriminator training should be turned off
d_g_loss_batch =GAN.train_on_batch(noise, train_label)
if epoch % 10 == 0:
samples = 10
x_fake = generator.predict(np.random.normal(loc=0, scale=1, size=(samples, 100)))
for k in range(samples):
plt.subplot(2, 5, k+1)
plt.imshow(x_fake[k].reshape(28, 28), cmap='gray')
plt.xticks([])
plt.yticks([])
plt.tight_layout()
plt.show()
print('Training is complete')
使用np.random.normal生成的噪声被作为输入给发生器:
noise=np.random.normal(loc=0, scale=1, size=(100,noise_shape))
gen_image = generator.predict(noise)
plt.imshow(noise)
plt.title('DCGAN Noise')