目前仍然在在广州的实习公司继续实习,为了更好的完成任务,以及未来的开题,现在必须仔细学习GAN。之前将GAN和DCGAN两篇论文仔细阅读完了,之后为了检验学习成果写下了这份DCGAN生成手写数字的代码。
虽然是GAN系列的第一篇文章,本想着先从GAN最初论文说起,但是由于很久没更新了博客或者公众号了,想赶紧更新一篇回馈粉丝。巧合的是,GAN的结果展示多为各种图片,转念一想利用讲解代码和展示结果方式来引导GAN系列的开始也不为一个合适的选择。下面开始介绍利用DCGAN生成书写数字。
虽然GAN系列第一篇上来就讲解代码,着实让很多小白们难易难以接受。因此我也首先简单介绍一下GAN的原理。
GAN(Generative Adversarial Network)全名叫做对抗生成网络或者生成对抗网络。GAN这一概念是由Ian Goodfellow于2014年提出,并迅速成为了非常火热的研究话题。目前,GAN的变种更是有上千种,2019年计算机界的诺贝尔奖“图灵奖”得主,深度学习先驱之一的Yann LeCun也曾说:“GAN及其变种是数十年来机器学习领域最有趣的想法。”
GAN的主要思想是零和博弈,GAN有两部分组成,一个生成器和一个判别器。生成器主要用于生成图像,判别器用于判别图像是否是“假的”,即图像是否由生成器的概率。GAN的训练可以看成式生成器与判别器之间相互对抗的过程。那么最理想的结果是生成器生成的图像在判别器的预测结果为0.5,即分不清图像是真实图像还是生成器生成的图像。
在原始GAN中,判别器与生成器都是原始的多层感知机即BP神经网络,在DCGAN模型中,BP神经网络都被卷积神经网络所替换。生成器主要是利用一系列反卷积操作将一维噪声向量转化成图像,判别器则是正常的卷积神经网络,将图像进行一系列提取特征之后在判断该图像来自生成器的概率。
接下来,我们来介绍利用DCGAN生成手写数字图像。本篇文章的代码全部使用keras进行编写,后端使用的是tensorflow1.14。该项目的源代码网址请移步:DCGAN-mnist。
首先给出DCGAN的类代码,这份代码主要由初始化函数、生成器搭建函数、判别器搭建函数,DCGAN的训练函数和保存DCGAN生成的图片的函数5部分构成。代码如下所示:
# -*- coding: utf-8 -*-
# @Time : 2019/9/15 9:26
# @Author : DaiPuWei
# @Email : [email protected]
# @Blog : https://daipuweiai.blog.csdn.net/
# @File : DCGAN.py
# @Software: PyCharm
import os
import numpy as np
from scipy.stats import truncnorm
from keras import Model
from keras import Input
from keras import Sequential
from keras.layers import Conv2D
from keras.layers import BatchNormalization
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import Activation
from keras.layers import Dropout
from keras.layers import Reshape
from keras.layers import Dense
from keras.layers import Flatten
from keras.optimizers import Adam
from keras.datasets import mnist
import matplotlib.pyplot as plt
class DCGAN(object):
def __init__(self,config):
"""
这是DCGAN的初始化函数
:param config: 网络模型参数配置类
"""
# 初始化网络相关超参数类
self.config = config
# 构建生成器与判别器
self.generotor_model = self.build_generator_model()
self.discriminator_model = self.build_discriminator_model()
# 构建DCGAN的优化器,并编译判别器
self.optimizier = Adam(lr=self.config.init_learning_rate,
beta_1=self.config.beta1,
decay=1e-8)
self.discriminator_model.compile(loss='binary_crossentropy',
optimizer=self.optimizier, metrics=['accuracy'])
# 构建DCGAN模型并进行编译
dcgan_input = Input(shape=self.config.generator_input_dim)
dcgan_output = self.discriminator_model(self.generotor_model(dcgan_input))
self.discriminator_model.trainable = False
self.dcgan = Model(dcgan_input,dcgan_output)
self.dcgan.compile(optimizer=self.optimizier, loss='binary_crossentropy', metrics=['accuracy'])
def build_generator_model(self):
"""
这是构建生成器网络的函数
:return:返回生成器模型generotor_model
"""
model = Sequential()
model.add(Dense(256*7*7,input_shape=self.config.generator_input_dim))
model.add(BatchNormalization(momentum=self.config.BatchNormalization_Momentum))
model.add(Activation('relu'))
model.add(Reshape((7,7,256)))
#generotor_model.add(UpSampling2D(size=(2,2)))
#generotor_model.add(Conv2D(64,5,5,padding='same'))
model.add(Conv2DTranspose(128,kernel_size=3,strides=2,padding='same'))
model.add(BatchNormalization(momentum=self.config.BatchNormalization_Momentum))
model.add(Activation('relu'))
model.add(Conv2DTranspose(64, kernel_size=3, strides=2, padding='same'))
model.add(BatchNormalization(momentum=self.config.BatchNormalization_Momentum))
model.add(Activation('relu'))
model.add(Conv2DTranspose(32, kernel_size=3,padding='same'))
model.add(BatchNormalization(momentum=self.config.BatchNormalization_Momentum))
model.add(Activation('relu'))
model.add(Conv2DTranspose(self.config.discriminator_input_dim[2], kernel_size=3,padding='same'))
model.add(Activation('tanh'))
model.summary()
noise = Input(shape=self.config.generator_input_dim)
image = model(noise)
return Model(noise,image)
def build_discriminator_model(self):
"""
这是构造判别器模型的函数
:return: 返回判别器模型discriminator_model
"""
model = Sequential()
model.add(Conv2D(64,kernel_size=3,strides=2,input_shape=self.config.discriminator_input_dim,padding='same'))
model.add(LeakyReLU(self.config.LeakyReLU_alpha))
model.add(Dropout(self.config.dropout_prob))
model.add(Conv2D(128,kernel_size=3,strides=2,padding='same'))
model.add(LeakyReLU(self.config.LeakyReLU_alpha))
model.add(Dropout(self.config.dropout_prob))
model.add(Conv2D(256,kernel_size=3,strides=2,padding='same'))
model.add(LeakyReLU(self.config.LeakyReLU_alpha))
model.add(Dropout(self.config.dropout_prob))
model.add(Conv2D(512,kernel_size=3,strides=2,padding='same'))
model.add(LeakyReLU(self.config.LeakyReLU_alpha))
model.add(Dropout(self.config.dropout_prob))
model.add(Flatten())
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.summary()
image = Input(shape=self.config.discriminator_input_dim)
validity = model(image)
return Model(image,validity)
def train(self,k,batch_size=256):
"""
这是DCGAN的训练函数
:param train_generator:训练数据生成器
:param batch_size:小批量样本规模
:param k:训练判别器次数
:return:
"""
(x_train, y_train), (X_test, y_test) = mnist.load_data()
x_train = (x_train.astype(np.float32) - 127.5) / 127.5
x_train = np.expand_dims(x_train,axis=3)
for epoch in np.arange(1,self.config.epoch+1):
half_batch = int(batch_size / 2)
g_losses = []
g_accuracy = []
d_losses = []
d_accuracy = []
d_loss = []
d_acc = []
for i in np.arange(k):
# 获取真实图片
idx = np.random.randint(0, x_train.shape[0], half_batch)
batch_real_images = x_train[idx]
# 生成一个batch_size的噪声用于生成图片
batch_noise =truncnorm.rvs(-1,1,size = (half_batch , self.config.generator_input_dim[0]))
batch_gen_images = self.generotor_model.predict(batch_noise)
batch_images = np.concatenate((batch_gen_images,batch_real_images))
# 构造标签
batch_gen_images_labels = truncnorm.rvs(0.0,0.3,size=(half_batch ,1))
batch_real_images_labels = truncnorm.rvs(0.7,1.2,size=(half_batch ,1))
batch_images_labels = np.concatenate((batch_gen_images_labels,batch_real_images_labels))
# 训练判别器
d_result = self.discriminator_model.train_on_batch(batch_images,batch_images_labels)
d_loss.append(d_result[0])
d_acc.append(d_result[1])
d_loss = np.average(d_loss)
d_acc = np.average(d_acc)
# 生成一个batch_size的噪声来训练生成器
batch_noise = truncnorm.rvs(-1,1,size=(half_batch ,self.config.generator_input_dim[0]))
batch_noise_label = truncnorm.rvs(0.7,1.2,size=(half_batch ,1))
g_result = self.dcgan.train_on_batch(batch_noise,batch_noise_label)
g_losses.append(g_result[0])
g_accuracy.append(g_result[1])
d_losses.append(d_loss)
d_accuracy.append(d_acc)
str = "Epoch:%05d,generator_loss:%.5f,generator_acc:%.5f,discriminator_loss:%.5f,discriminator_accuracy%.5f" \
% (epoch, g_result[0],g_result[1],d_loss, d_acc)
print(str)
if epoch % self.config.save_interval == 0:
model_dcgan = "Epoch%05dgenerator_loss%.5fgenerator_accuracy%.5fdiscriminator_loss%.5fdiscriminator_accuracy%.5f.h5" \
% (epoch,np.average(g_losses), np.average(g_accuracy),np.average(d_losses),np.average(d_accuracy))
#self.dcgan.save(os.path.join(self.config.save_weight_dir,'dcgan.h5'))
self.dcgan.save(os.path.join(self.config.save_weight_dir,model_dcgan))
self.save_image(epoch)
def save_image(self,epoch):
"""
这是保存生成图片的函数
:param images: 图片集
:param epoch:周期数
:return:
"""
rows, cols = 5, 5
noise = truncnorm.rvs(-1, 1, size=(rows * cols, self.config.generator_input_dim[0]))
images = self.generotor_model.predict(noise)
fig, axs = plt.subplots(rows, cols)
cnt = 0
for i in range(rows):
for j in range(cols):
axs[i, j].imshow(images[cnt, :, :, 0], cmap='gray')
axs[i, j].axis('off')
cnt += 1
fig.savefig(os.path.join(self.config.result_path,"mnist-{0:0>5}.png".format(epoch)), dpi=300)
plt.close()
接下来我们给出,DCGAN训练与生成手写数字的程序,如下所示。在这份代码中,我们首先构造了一个属于mnist数据集的参数配置类MnistConfig,该类继承自基本参数配置类Config。Config的代结构请详见github链接:DCGAN-mnist,在此我们不在具体给给出。
# -*- coding: utf-8 -*-
# @Time : 2019/9/15 21:58
# @Author : DaiPuWei
# @Email : [email protected]
# @Blog : https://daipuweiai.blog.csdn.net/
# @File : main.py
# @Software: PyCharm
import os
import datetime
from Config.Config import Config
from DCGAN.DCGAN import DCGAN
class MnistConfig(Config):
def __init__(self):
#super(Config, self).__init__()
Config.__init__(self)
time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
self.save_weight_dir = os.path.join(Config.get_save_weight_dir(self),time)
if not os.path.exists(self.save_weight_dir):
os.mkdir(self.save_weight_dir)
self.result_path = os.path.join(Config.get_result_path(self),time)
if not os.path.exists(self.result_path):
os.mkdir(self.result_path)
self.batch_size = 256
def run_main():
"""
这是主函数
"""
cfg = MnistConfig()
dcgan = DCGAN(cfg)
dcgan.train(20,cfg.batch_size)
if __name__ == '__main__':
run_main()
CNN的训练过程主要就是根据损失函数利用梯度下降及其改进算法进行训练,更新网络参数。但是不同于CNN的训练,GAN的训练是一个动态的过程,GAN的目标是寻求判别器与生成器之间的动态平衡。因此我们不能只靠梯度下降算法进行训练模型。
在DCGAN训练过程中有如下几点小技巧可以直接采纳:
至此,GAN系列的第一篇到此完全结束。在这一篇文章中,我们领略了DCGAN的强大。接下来我们原始GAN开始进行讲解GAN发展。敬请期待GAN系列第二篇,GAN论文详解。