生成器的最后一层不使用 sigmoid
,使用 tanh
代替
使用噪声作为生成器的输入时,生成噪声的步骤使用 正态分布 的采样来产生,而不使用均匀分布
训练 discriminator 的时候,将 fake img
的标签设为 1
,real img
的标签设成 0
,这样更有利于其训练
在 generator 和 discriminator 设计的时候都使用 dropout
层来增加随机性;或者在 discriminator 的标签中 添加噪声 来提高随机性;因为随机性对于 GAN 的训练有帮助
在 discriminator 和 generator 中都使用 LeakyRelu
来作为激活函数而不用传统的 Relu
用 Conv2DTranspose,stride=2
来代替上采样操作;用 Conv2D,stride=2
来代替下采样操作(maxpooling)
在生成的图像中,经常会见到棋盘状伪影,这是由于生成器中像素空间的不均匀覆盖造成的(如下图),为了解决这个问题,每当生成器和判别器中都使用步进的Conv2DTranspose或Conv2D时,使用的 内核大小要能够被步幅大小整除。例如 stride=2,kernel=(4,4)
discriminator 的容量和能力一定要小于 generator,因为判别远比生成容易,如果 discriminator 太强了,反而不利于 generator 的学习,就像一个太过于严厉的老师是不利于学生大踏步地进行创新和进步的,老师一定要温和。比例大约控制在 generator 的容量是 discriminator 的 10
倍左右。
在训练的时候,如果不想调整网络的参数,那么可以尝试在训练的时候让一个网络训练好几次,然后另外一个网络训练一次,例如,如果 generator 很强,那么久让 discriminator 训练 3 次,generator 更新一次参数,加个 for 循环即可,亲测有效。
generator 在训练的过程中,前期大概率会被 discriminator 压制;因为刚开始生成的东西还很简单,因此要保证在训练的时候 generator 在前几个 epochs 不能 loss 很快地上升到 1,这样的话不利于后面的训练 ,正确的引导方式应该是设计 generator 的 loss 先上升到 0.7,0.8 左右,然后再慢慢降下来,这样就很有利于训练 GAN 网络
一个被良好训练的 GAN 网络应该具有下面的 loss 走向:鉴别器和生成器都在波动,而不是一方的 loss 很快上升到 1,而另一方很快降到 0;
定义网络的顺序是:
(dis.trainable=False)
,这时候千万不要再次 compile discriminator!!如果这里你没看懂,一定要在下面的代码中仔细留意关于 discriminator 的compile 部分,因为如果这里出问题,你最终即使网络训练的时候通过调整参数而达到了上图中演示的那种良好的交替波动情况,最后的输出也大概率属于下图中的情况,如果你遇到了下图的情况,请仔细考虑你的 compile 步骤:
在训练 GAN 之前,尽量对数据进行标准化,图片数据除以 255.
即可,下面代码中也有演示,对数据规范化绝对是一个好习惯
import keras,os
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.losses import *
from keras import backend as K
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
from keras.preprocessing import image
from keras.datasets import fashion_mnist,cifar10,cifar100,mnist
from keras.utils import to_categorical
os.environ["CUDA_VISIBLE_DEVICES"] = " 2"
def generator(input_shape):
inputs = Input(input_shape)
x = Dense(128 * 16 * 16)(inputs)
x = LeakyReLU()(x)
x = Reshape((16, 16, 128))(x)
x = Conv2D(256, 5, padding = 'same')(x)
x = LeakyReLU()(x)
x = Conv2DTranspose(256, 4, strides = 2, padding = 'same')(x)
x = LeakyReLU()(x)
x = Conv2D(256, 5, padding = 'same')(x)
x = LeakyReLU()(x)
x = Conv2D(256, 5, padding = 'same')(x)
x = LeakyReLU()(x)
x = Conv2D(3, 7, activation='tanh', padding = 'same')(x)
return Model(inputs,x)
gen = generator((100,))
def discriminator(input_shape):
inputs = Input(input_shape)
x = Conv2D(128, 3)(inputs)
x = LeakyReLU()(x)
x = Conv2D(128, 4, strides=2)(x)
x = LeakyReLU()(x)
x = Conv2D(128, 4, strides=2)(x)
x = LeakyReLU()(x)
x = Conv2D(128, 4, strides=2)(x)
x = LeakyReLU()(x)
x = Flatten()(x)
x = Dropout(0.4)(x)
x = Dense(1, activation='sigmoid')(x) #分类层
return Model(inputs,x)
dis = discriminator((32,32,3))
dis.compile(loss=keras.losses.binary_crossentropy,optimizer= keras.optimizers.RMSprop(lr = 0.0008,clipvalue = 1.0,decay=1e-8))
def GAN():
dis.trainable=False
gan_input = Input((100,))
fake_image = gen(gan_input)
score = dis(fake_image)
return Model(gan_input,score)
gan = GAN()
gan.compile(loss=keras.losses.binary_crossentropy,optimizer=keras.optimizers.RMSprop(lr=0.0004, clipvalue=1.0, decay=1e-8))
(x_train,y_train),(x_test,y_test)= cifar10.load_data()
y_train_label = y_train
y_test_label = y_test
x_train = x_train[y_train.flatten() == 7] #选择马类数据即可
x_train = x_train.reshape(5000,32,32,3).astype('float32')/255.
epochs = 4000
batch_size = 64
valid = np.ones((batch_size,1))
fake = np.zeros((batch_size,1))
generated_img = []
discriminator_loss = []
generator_loss = []
save_dir = './A-GAN-PHOTO'
for epoch in range(epochs):
noise = np.random.normal(0,1,size=(batch_size,100))
img_index = np.random.randint(0,5000,batch_size)
fake_img = gen.predict(noise)
real_img = x_train[img_index]
data = np.concatenate([fake_img, real_img])
label = np.concatenate([fake,valid])
label += 0.05 * np.random.random(label.shape)
d_loss = dis.train_on_batch(data,label)
# ---------------------
# 训练生成模型
# ---------------------
noise_ = np.random.normal(0,1,size=(batch_size,100))
g_loss = gan.train_on_batch(noise_, valid)
if epoch%100 == 0:
im = fake_img[0]
generated_img.append(im)
img = image.array_to_img(fake_img[0] * 255, scale=False)
img.save(os.path.join(save_dir, 'generated_horse' + str(epoch) + '.png')) #保存一张生成图像
img = image.array_to_img(real_img[0] * 255, scale=False)
img.save(os.path.join(save_dir, 'real_horse' + str(epoch) +'.png')) #保存一张真实图像用于对比
print('discriminator_loss:',d_loss)
print('adversal_loss:',g_loss)
discriminator_loss.append(d_loss)
generator_loss.append(g_loss)
# discriminator_loss.append(d_loss[-1])
# generator_loss.append(g_loss[-1])
# print("d_loss:%f"%d_loss[-1])
# print("g_loss:%f"%g_loss[-1])
print("epoch:%d" % epoch + "========")
fig, axes = plt.subplots(nrows=2, ncols=20, sharex=True, sharey=True, figsize=(80,12))
imgs = generated_img
for image, row in zip([imgs[:20], imgs[20:40]], axes):
for img, ax in zip(image, row):
ax.imshow(img)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
fig.tight_layout(pad=0.1)
plt.plot(discriminator_loss,label='discriminator_loss')
plt.plot(generator_loss,label='generator_loss')
plt.legend()
最后,祝大家的 GAN 网络都能训练有素,耗子尾汁,别那么不讲武德。