上期我们介绍了
生成对抗网络 | 原理及训练过程
同样地,我们依旧通过实验来巩固我们刚刚所学的知识点。本次实验是基于Jupyer Notebook、Anaconda Python3.7与Keras环境。数据集是利用Minst手写体图像数据集。
5.3.1 代码
1. # chapter5/5_3_GAN.ipynb
2. import random
3. import numpy as np
4. from keras.layers import Input
5. from keras.layers.core import Reshape,Dense,Dropout,Activation,Flatten
6. from keras.layers.advanced_activations import LeakyReLU
7. from keras.layers.convolutional import Convolution2D, MaxPooling2D, ZeroPadding2D, Deconv2D, UpSampling2D
8. from keras.regularizers import *
9. from keras.layers.normalization import *
10. from keras.optimizers import *
11. from keras.datasets import mnist
12. import matplotlib.pyplot as plt
13. from keras.models import Model
14. from tqdm import tqdm
15. from IPython import display
1. 读取数据集
1. img_rows, img_cols = 28, 28
2.
3. # 数据集的切分与混洗(shuffle)
4. (X_train, y_train), (X_test, y_test) = mnist.load_data()
5.
6. X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
7. X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)
8. X_train = X_train.astype('float32')
9. X_test = X_test.astype('float32')
10. X_train /= 255
11. X_test /= 255
12.
13. print(np.min(X_train), np.max(X_train))
14. print('X_train shape:', X_train.shape)
15. print(X_train.shape[0], 'train samples')
16. print(X_test.shape[0], 'test samples')
0.0 1.0
X_train shape: (60000, 1, 28,28)
60000 train samples
10000 test samples
2. 超参数设置
1. shp = X_train.shape[1:]
2. dropout_rate = 0.25
3.
4. # 优化器
5. opt = Adam(lr=1e-4)
6. dopt = Adam(lr=1e-5)
3. 定义生成器
1. K.set_image_dim_ordering('th') # 用theano的图片输入顺序
2. # 生成1 * 28 * 28的图片
3. nch = 200
4. g_input = Input(shape=[100])
5. H = Dense(nch*14*14, kernel_initializer='glorot_normal')(g_input)
6. H = BatchNormalization()(H)
7. H = Activation('relu')(H)
8. H = Reshape( [nch, 14, 14] )(H) # 转成200 * 14 * 14
9. H = UpSampling2D(size=(2, 2))(H)
10. H = Convolution2D(100, (3, 3), padding="same", kernel_initializer='glorot_normal')(H)
11. H = BatchNormalization()(H)
12. H = Activation('relu')(H)
13. H = Convolution2D(50, (3, 3), padding="same", kernel_initializer='glorot_normal')(H)
14. H = BatchNormalization()(H)
15. H = Activation('relu')(H)
16. H = Convolution2D(1, (1, 1), padding="same", kernel_initializer='glorot_normal')(H)
17. g_V = Activation('sigmoid')(H)
18. generator = Model(g_input,g_V)
19. generator.compile(loss='binary_crossentropy', optimizer=opt)
20. generator.summary()
4. 定义辨别器
1. # 辨别是否来自真实训练集
2. d_input = Input(shape=shp)
3. H = Convolution2D(256, (5, 5), activation="relu", strides=(2, 2), padding="same")(d_input)
4. H = LeakyReLU(0.2)(H)
5. H = Dropout(dropout_rate)(H)
6. H = Convolution2D(512, (5, 5), activation="relu", strides=(2, 2), padding="same")(H)
7. H = LeakyReLU(0.2)(H)
8. H = Dropout(dropout_rate)(H)
9. H = Flatten()(H)
10. H = Dense(256)(H)
11. H = LeakyReLU(0.2)(H)
12. H = Dropout(dropout_rate)(H)
13. d_V = Dense(2,activation='softmax')(H)
14. discriminator = Model(d_input,d_V)
15. discriminator.compile(loss='categorical_crossentropy', optimizer=dopt)
16. discriminator.summary()
5. 构造生成对抗网络
1. # 冷冻训练层
2. def make_trainable(net, val):
3. net.trainable = val
4. for l in net.layers:
5. l.trainable = val
6. make_trainable(discriminator, False)
7.
8. # 构造GAN
9. gan_input = Input(shape=[100])
10. H = generator(gan_input)
11. gan_V = discriminator(H)
12. GAN = Model(gan_input, gan_V)
13. GAN.compile(loss='categorical_crossentropy', optimizer=opt)
14. GAN.summary()
6. 训练
1. # 描绘损失收敛过程
2. def plot_loss(losses):
3. display.clear_output(wait=True)
4. display.display(plt.gcf())
5. plt.figure(figsize=(10,8))
6. plt.plot(losses["d"], label='discriminitive loss')
7. plt.plot(losses["g"], label='generative loss')
8. plt.legend()
9. plt.show()
10.
11.
12. # 描绘生成器生成图像
13. def plot_gen(n_ex=16,dim=(4,4), figsize=(10,10) ):
14. noise = np.random.uniform(0,1,size=[n_ex,100])
15. generated_images = generator.predict(noise)
16.
17. plt.figure(figsize=figsize)
18. for i in range(generated_images.shape[0]):
19. plt.subplot(dim[0],dim[1],i+1)
20. img = generated_images[i,0,:,:]
21. plt.imshow(img)
22. plt.axis('off')
23. plt.tight_layout()
24. plt.show()
25.
26. # 抽取训练集样本
27. ntrain = 10000
28. trainidx = random.sample(range(0,X_train.shape[0]), ntrain)
29. XT = X_train[trainidx,:,:,:]
30.
31. # 预训练辨别器
32. noise_gen = np.random.uniform(0,1,size=[XT.shape[0],100])
33. generated_images = generator.predict(noise_gen) # 生成器产生样本
34. X = np.concatenate((XT, generated_images))
35. n = XT.shape[0]
36. y = np.zeros([2*n,2]) # 构造辨别器标签 one-hot encode
37. y[:n,1] = 1
38. y[n:,0] = 1
39.
40. make_trainable(discriminator,True)
41. discriminator.fit(X,y, epochs=1, batch_size=32)
42. y_hat = discriminator.predict(X)
1. # 计算辨别器的准确率
2. y_hat_idx = np.argmax(y_hat,axis=1)
3. y_idx = np.argmax(y,axis=1)
4. diff = y_idx-y_hat_idx
5. n_total = y.shape[0]
6. n_right = (diff==0).sum()
7.
8. print( "(%d of %d) right" % (n_right, n_total))
1. def train_for_n(nb_epoch=5000, plt_frq=25,BATCH_SIZE=32):
2. for e in tqdm(range(nb_epoch)):
3.
4. # 生成器生成样本
5. image_batch = X_train[np.random.randint(0,X_train.shape[0],size=BATCH_SIZE),:,:,:]
6. noise_gen = np.random.uniform(0,1,size=[BATCH_SIZE,100])
7. generated_images = generator.predict(noise_gen)
8.
9. # 训练辨别器
10. X = np.concatenate((image_batch, generated_images))
11. y = np.zeros([2*BATCH_SIZE,2])
12. y[0:BATCH_SIZE,1] = 1
13. y[BATCH_SIZE:,0] = 1
14.
15. # 存储辨别器损失loss
16. make_trainable(discriminator,True)
17. d_loss = discriminator.train_on_batch(X,y)
18. losses["d"].append(d_loss)
19.
20. # 生成器生成样本
21. noise_tr = np.random.uniform(0,1,size=[BATCH_SIZE,100])
22. y2 = np.zeros([BATCH_SIZE,2])
23. y2[:,1] = 1
24.
25. # 存储生成器损失loss
26. make_trainable(discriminator,False) # 辨别器的训练关掉
27. g_loss = GAN.train_on_batch(noise_tr, y2)
28. losses["g"].append(g_loss)
29.
30. # 更新损失loss图
31. if e%plt_frq == plt_frq-1:
32. plot_loss(losses)
33. plot_gen()
34. train_for_n(nb_epoch=1000, plt_frq=10,BATCH_SIZE=128)
5.3.2 结果分析
从模型输出的loss我们可以知道生成器与辨别器两者拟合的loss并不是特别地好,因此我们可以通过调参来解决。主要调参方向有以下四点:
1. batch size
2. adam优化器的learning rate
3. 迭代次数nb_epoch
4. 生成器generator和辨别器discriminator的网络结构
5.4 小结
好了,到这里,我们就已经将生成对抗网络(GAN)的知识点讲完了。大家在掌握了整个流程之后,就可以将笔者的代码修改成自己所需要的场景,进而训练自己的GAN模型了。
最后,笔者在本章介绍的GAN只是2014年的开山之作,后面有很多人基于GAN提出了许多有趣的实验,但是所用的网络原理都差不多,这里就不一一赘述了。而且GAN的应用范围非常广阔,比如市面上很火的“换脸”软件,大多都是基于GAN的原理去做的。甚至我们也可以利用GAN去做数据增强,比如在我们缺少训练集的时候,可以考虑用GAN去生成一些数据,扩充我们的训练样本。
下一期,我们将讲授
李宏毅老师的无监督学习讲座PPT总结
敬请期待~
关注我的微信公众号~不定期更新相关专业知识~
内容 |阿力阿哩哩
编辑 | 阿璃
点个“在看”,作者高产似那啥~