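My internship in Guangzhou ended on October 15. For nearly a month since then I have been busy with internship paperwork at school, my internship defense, and the opening report for my graduation thesis, so I did not get around to writing up the notes I had planned before the internship ended. Starting today, I will resume updating this blog.
Earlier posts covered the theory behind the original GAN and DCGAN, along with code that uses DCGAN to generate handwritten digit images. Interested readers may refer to those posts.
In this post we go through the details of the CGAN (Conditional GAN) paper, Conditional Generative Adversarial Nets. The Keras code for generating handwritten digits with CGAN is available at CGAN-mnist.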
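To set up the CGAN theory, let us first review GAN. A GAN consists of two networks, a generator $G$ and a discriminator $D$. The generator maps random Gaussian noise to images ("fake images"), while the discriminator outputs the probability that its input image is real, i.e., drawn from the training data rather than produced by the generator.
Let $x$ denote the data, $p_g$ the generator's distribution, and $p_z(z)$ the prior on the input noise. The generator maps noise $z$ to $G(z;\theta_g)$, and the discriminator's output on $x$ is $D(x;\theta_d)$.
The goal of a GAN is to create something out of nothing and pass the fake off as real: the fake images produced by $G$ should fool $D$. At the optimum, $D$ outputs 0.5 on the generator's images, meaning it cannot tell whether they are real or fake. The GAN objective is: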
$$\min_{G}\max_{D}V(D,G)=\mathbb{E}_{x\sim p_{data}(x)}[\log D(x)]+\mathbb{E}_{z\sim p_{z}(z)}[\log(1-D(G(z)))] \tag{1}$$
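As a quick numerical check of Eq. (1), the NumPy sketch below estimates $V(D,G)$ from batches of discriminator outputs. The function name `gan_value` and the sample probabilities are made up for illustration. A discriminator that separates real from fake pushes $V$ toward its upper bound of 0, while at the equilibrium described above ($D=0.5$ everywhere) $V=2\log 0.5=-\log 4$.

```python
import numpy as np

def gan_value(d_real, d_fake, eps=1e-12):
    """Monte Carlo estimate of V(D, G) in Eq. (1).

    d_real: discriminator outputs D(x) on real samples x ~ p_data
    d_fake: discriminator outputs D(G(z)) on generated samples, z ~ p_z
    """
    # Clip to avoid log(0) for saturated outputs
    d_real = np.clip(np.asarray(d_real, dtype=np.float64), eps, 1 - eps)
    d_fake = np.clip(np.asarray(d_fake, dtype=np.float64), eps, 1 - eps)
    return np.mean(np.log(d_real)) + np.mean(np.log(1.0 - d_fake))

# A confident, correct discriminator: V is close to its maximum of 0
print(gan_value([0.99, 0.98], [0.01, 0.02]))
# Equilibrium: D outputs 0.5 on everything, so V equals -log 4
print(gan_value([0.5] * 4, [0.5] * 4), -np.log(4))
```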
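Next we turn to CGAN itself. The generator of the original GAN can only produce images from random noise, so we have no control over what an image shows (i.e., its label), and the discriminator only sees an image and judges whether it is real or generated. CGAN's main contribution is to feed extra information $y$ into both the generator and the discriminator. $y$ can be any kind of auxiliary information, such as a class label. This lets a GAN be trained on images together with their labels and, at test time, generate an image of a specified class from a given label.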
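A minimal NumPy sketch of feeding $y$ into the generator, assuming $y$ is a one-hot class label that is simply concatenated with the noise $z$ before a fully connected hidden layer, in the spirit of the paper's MLP setup. The sizes (100-d noise, 10 classes, 256 hidden units), the random weights, and the helper names `one_hot` and `conditional_hidden` are illustrative assumptions, not the paper's exact configuration. Note that the Keras code later in this post conditions differently: it multiplies the noise element-wise by a learned label embedding.

```python
import numpy as np

rng = np.random.default_rng(0)
noise_dim, num_classes, hidden_dim = 100, 10, 256  # illustrative sizes only

def one_hot(labels, num_classes):
    """Turn integer labels of shape (n,) into one-hot rows of shape (n, num_classes)."""
    out = np.zeros((len(labels), num_classes))
    out[np.arange(len(labels)), labels] = 1.0
    return out

# Randomly initialized weights of one fully connected hidden layer
W = rng.normal(0, 0.02, size=(noise_dim + num_classes, hidden_dim))
b = np.zeros(hidden_dim)

def conditional_hidden(z, y):
    """Concatenate noise z with the one-hot condition y, then apply Dense + LeakyReLU(0.2)."""
    h_in = np.concatenate([z, one_hot(y, num_classes)], axis=1)  # (n, noise_dim + num_classes)
    h = h_in @ W + b
    return np.where(h > 0, h, 0.2 * h)

z = rng.normal(size=(4, noise_dim))
y = np.array([0, 3, 3, 9])
print(conditional_hidden(z, y).shape)  # (4, 256)
```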
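In the CGAN paper, both networks are MLPs (fully connected networks). In the generator, the input noise $z \sim p_z(z)$ and the extra information $y$ are combined into a joint hidden representation that feeds the hidden layers. Likewise, in the discriminator the image $x$ and the extra information $y$ are combined as the input to the hidden layers. The CGAN architecture is shown in the figure below:
The CGAN objective can then be written as: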
$$\min_{G}\max_{D}V(D,G)=\mathbb{E}_{x\sim p_{data}(x)}[\log D(x|y)]+\mathbb{E}_{z\sim p_{z}(z)}[\log(1-D(G(z|y)))] \tag{2}$$
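The Keras code in this post does not optimize Eq. (2) literally; it trains with binary cross-entropy, using target 1 for real (image, label) pairs and target 0 for generated ones. The NumPy sketch below uses made-up discriminator outputs to check two things. First, the discriminator loss the code computes, `0.5*(real_d_loss + fake_d_loss)`, equals $-\tfrac12 V$ estimated on the batch. Second, training the generator with target 1 minimizes $-\log D(G(z|y))$, the non-saturating generator loss from the original GAN paper, rather than literally minimizing $\log(1-D(G(z|y)))$. The helper `bce` is an assumption that approximates what Keras' `binary_crossentropy` averages (Keras also clips predictions internally).

```python
import numpy as np

def bce(targets, preds, eps=1e-7):
    """Mean binary cross-entropy over a batch."""
    preds = np.clip(preds, eps, 1 - eps)
    return float(np.mean(-(targets * np.log(preds) + (1 - targets) * np.log(1 - preds))))

# Hypothetical discriminator outputs on real pairs D(x|y) and fake pairs D(G(z|y))
d_real = np.array([0.9, 0.8, 0.7])
d_fake = np.array([0.2, 0.1, 0.3])

# Discriminator: real pairs labelled 1, fake pairs labelled 0, averaged with weight 0.5
d_loss = 0.5 * (bce(np.ones_like(d_real), d_real) + bce(np.zeros_like(d_fake), d_fake))
# Batch estimate of V from Eq. (2); minimizing d_loss maximizes v_hat
v_hat = np.mean(np.log(d_real)) + np.mean(np.log(1 - d_fake))
print(d_loss, -0.5 * v_hat)

# Generator: fake pairs labelled 1, i.e. minimize -log D(G(z|y))
g_loss = bce(np.ones_like(d_fake), d_fake)
print(g_loss)
```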
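Below are handwritten-digit images generated in the CGAN paper. Each row corresponds to one label; for example, the first row shows images with label 0.
Next we walk through the Keras code that generates handwritten digits with CGAN; the GitHub repository is CGAN-mnist. First, the CGAN network architecture: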
# -*- coding: utf-8 -*-
# @Time : 2019/10/8 13:39
# @Author : Dai PuWei
# @File : CGAN.py
# @Software: PyCharm
import os
import datetime
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
from scipy.stats import truncnorm
from keras import Input
from keras import Model
from keras import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import BatchNormalization
from keras.layers import LeakyReLU
from keras.layers import Dropout
from keras.layers import Flatten
from keras.layers import Embedding
from keras.layers.merge import multiply
from keras.optimizers import Adam
from keras.utils.generic_utils import Progbar


def make_trainable(net, val):
    """ Freeze or unfreeze layers
    """
    net.trainable = val
    for l in net.layers:
        l.trainable = val


class CGAN(object):

    def __init__(self, config, weight_path=None):
        """
        Initialize the CGAN.
        :param config: instance of the configuration class
        :param weight_path: path to a weight file, None by default
        """
        self.config = config
        self.build_cgan_model()
        if weight_path is not None:
            self.cgan.load_weights(weight_path, by_name=True)

    def build_cgan_model(self):
        """
        Build the combined CGAN model.
        :return:
        """
        # Inputs
        self.generator_noise_input = Input(shape=(self.config.generator_noise_input_dim,))
        self.condational_label_input = Input(shape=(1,), dtype='int32')
        self.discriminator_image_input = Input(shape=self.config.discriminator_image_input_dim)
        # Optimizer
        self.optimizer = Adam(lr=2e-4, beta_1=0.5)
        # Build and compile the discriminator, then build the generator
        self.discriminator_model = self.build_discriminator_model()
        self.discriminator_model.compile(optimizer=self.optimizer, loss=['binary_crossentropy'],
                                         metrics=['accuracy'])
        self.generator_model = self.build_generator()
        # Build the combined model G -> D. Freezing D here only affects models compiled
        # afterwards, so the already-compiled discriminator_model still updates its weights.
        self.discriminator_model.trainable = False
        self.cgan_input = [self.generator_noise_input, self.condational_label_input]
        generator_output = self.generator_model(self.cgan_input)
        cgan_output = self.discriminator_model([generator_output, self.condational_label_input])
        self.cgan = Model(self.cgan_input, cgan_output)
        # Compile the combined model (only the generator's weights are trained through it)
        self.cgan.compile(optimizer=self.optimizer, loss=['binary_crossentropy'])

    def build_discriminator_model(self):
        """
        Build the discriminator model.
        :return: a Model mapping [image, label] to the probability that the pair is real
        """
        image_size = int(np.prod(self.config.discriminator_image_input_dim))
        model = Sequential()
        model.add(Dense(512, input_dim=image_size))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        # Note: the dropout rate reuses config.LeakyReLU_alpha
        model.add(Dropout(self.config.LeakyReLU_alpha))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(Dropout(self.config.LeakyReLU_alpha))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.config.discriminator_image_input_dim)
        label = Input(shape=(1,), dtype='int32')
        # Map each label to a learned vector with the same size as the flattened image
        label_embedding = Flatten()(Embedding(self.config.condational_label_num, image_size)(label))
        flat_img = Flatten()(img)
        # Condition on the label by element-wise multiplication
        model_input = multiply([flat_img, label_embedding])
        validity = model(model_input)
        return Model([img, label], validity)

    def build_generator(self):
        """
        Build the generator model.
        :return: a Model mapping [noise, label] to an image in [-1, 1]
        """
        model = Sequential()
        model.add(Dense(256, input_dim=self.config.generator_noise_input_dim))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(BatchNormalization(momentum=self.config.batchnormalization_momentum))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(BatchNormalization(momentum=self.config.batchnormalization_momentum))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=self.config.LeakyReLU_alpha))
        model.add(BatchNormalization(momentum=self.config.batchnormalization_momentum))
        model.add(Dense(int(np.prod(self.config.discriminator_image_input_dim)), activation='tanh'))
        model.add(Reshape(self.config.discriminator_image_input_dim))
        model.summary()

        noise = Input(shape=(self.config.generator_noise_input_dim,))
        label = Input(shape=(1,), dtype='int32')
        # Map each label to a learned vector with the same size as the noise
        label_embedding = Flatten()(Embedding(self.config.condational_label_num,
                                              self.config.generator_noise_input_dim)(label))
        # Condition on the label by element-wise multiplication
        model_input = multiply([noise, label_embedding])
        img = model(model_input)
        return Model([noise, label], img)

    def train(self, train_datagen, epoch, k, batch_size=256):
        """
        Train the CGAN.
        :param train_datagen: training data iterator (e.g. MnistGenerator)
        :param epoch: number of epochs
        :param k: number of discriminator updates per generator update
        :param batch_size: mini-batch size; must match the batch size of train_datagen
        :return:
        """
        time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        model_path = os.path.join(self.config.model_dir, time)
        if not os.path.exists(model_path):
            os.makedirs(model_path)
        train_result_path = os.path.join(self.config.train_result_dir, time)
        if not os.path.exists(train_result_path):
            os.makedirs(train_result_path)

        # Per-epoch average losses, used for the loss curve at the end of training
        epoch_cgan_losses = []
        epoch_d_losses = []
        for ep in np.arange(1, epoch + 1).astype(np.int32):
            cgan_losses = []
            d_losses = []
            # Progress bar
            length = train_datagen.batch_num
            progbar = Progbar(length)
            print('Epoch {}/{}'.format(ep, epoch))
            step = 0
            while True:
                # Stop once the iterator has finished a full pass over the dataset
                if train_datagen.epoch != ep:
                    break
                # Fetch real images with their class labels; real pairs get target 1
                batch_real_images, batch_real_labels = train_datagen.next_batch()
                batch_real_num_labels = np.ones((batch_size, 1))
                #batch_real_num_labels = truncnorm.rvs(0.7, 1.2, size=(batch_size, 1))
                # Sample random noise used to forge fake images
                batch_noises = np.random.normal(0, 1, size=(batch_size, self.config.generator_noise_input_dim))
                d_loss = []
                for i in np.arange(k):
                    # Fake pairs get target 0 and reuse the real class labels as conditions
                    batch_fake_num_labels = np.zeros((batch_size, 1))
                    #batch_fake_num_labels = truncnorm.rvs(0.0, 0.3, size=(batch_size, 1))
                    batch_fake_labels = deepcopy(batch_real_labels)
                    batch_fake_images = self.generator_model.predict([batch_noises, batch_fake_labels])
                    # Train the discriminator on real and fake pairs
                    real_d_loss = self.discriminator_model.train_on_batch([batch_real_images, batch_real_labels],
                                                                          batch_real_num_labels)
                    fake_d_loss = self.discriminator_model.train_on_batch([batch_fake_images, batch_fake_labels],
                                                                          batch_fake_num_labels)
                    d_loss.append(list(0.5 * np.add(real_d_loss, fake_d_loss)))
                d_losses.append(list(np.average(d_loss, 0)))
                # Train the generator through the combined model: random labels, target 1
                #batch_num_labels = truncnorm.rvs(0.7, 1.2, size=(batch_size, 1))
                batch_num_labels = np.ones((batch_size, 1))
                batch_labels = np.random.randint(0, self.config.condational_label_num, batch_size).reshape(-1, 1)
                cgan_loss = self.cgan.train_on_batch([batch_noises, batch_labels], batch_num_labels)
                cgan_losses.append(cgan_loss)
                # Update the progress bar
                progbar.update(step + 1, [('cgan_loss', cgan_losses[step]),
                                          ('discriminator_loss', d_losses[step][0]),
                                          ('acc', d_losses[step][1])])
                step += 1
            epoch_cgan_losses.append(np.average(cgan_losses))
            epoch_d_losses.append(np.average(d_losses, 0)[0])
            if ep % self.config.save_epoch_interval == 0:
                model_cgan = "Epoch{}dcgan_loss{}discriminator_loss{}acc{}.h5".format(
                    ep, np.average(cgan_losses), np.average(d_losses, 0)[0], np.average(d_losses, 0)[1])
                self.cgan.save(os.path.join(model_path, model_cgan))
                save_dir = os.path.join(train_result_path, "Epoch{}".format(ep))
                if not os.path.exists(save_dir):
                    os.mkdir(save_dir)
                self.save_image(int(ep), save_dir)
        # Plot the per-epoch loss curves
        plt.plot(np.arange(1, epoch + 1), epoch_cgan_losses, 'b-', label='cgan-loss')
        plt.plot(np.arange(1, epoch + 1), epoch_d_losses, 'r-', label='d-loss')
        plt.grid(True)
        plt.legend(loc="best")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.savefig(os.path.join(train_result_path, "loss.png"))
        plt.close()

    def save_image(self, epoch, save_path):
        """
        Save a 10x10 grid of generated digits; row i is conditioned on label i.
        :param epoch: epoch number (used in the file name)
        :param save_path: directory in which to save the image
        :return:
        """
        rows, cols = 10, 10
        fig, axs = plt.subplots(rows, cols)
        for i in range(rows):
            label = np.array([i] * cols).astype(np.int32).reshape(-1, 1)
            noise = np.random.normal(0, 1, (cols, self.config.generator_noise_input_dim))
            images = self.generator_model.predict([noise, label])
            # Map the tanh output from [-1, 1] back to pixel values in [0, 255]
            images = 127.5 * images + 127.5
            for j in range(cols):
                axs[i, j].imshow(images[j, :, :, 0].astype(np.int32), cmap='gray')
                axs[i, j].axis('off')
        fig.savefig(os.path.join(save_path, "mnist-{}.png".format(epoch)), dpi=600)
        plt.close()

    def generate_image(self, label):
        """
        Generate a single fake image for the given label.
        :param label: class label (integer)
        :return: generated image with pixel values in [0, 255]
        """
        noise = truncnorm.rvs(-1, 1, size=(1, self.config.generator_noise_input_dim))
        label = np.array([[label]], dtype=np.int32)  # shape (1, 1), matching the label Input
        image = self.generator_model.predict([noise, label])[0]
        image = 127.5 * (image + 1)
        return image
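Unlike the concatenation described in the paper, both networks above condition on the label through `Embedding` followed by `multiply`: each class owns a learned vector, and the network input is the element-wise product of that vector with the noise (generator) or with the flattened image (discriminator). The NumPy sketch below mimics the generator side, with a random table standing in for the trained embedding weights; `embedding_table` and `condition_noise` are illustrative names, not part of the original code.

```python
import numpy as np

rng = np.random.default_rng(42)
num_classes, noise_dim = 10, 100

# Stand-in for the Embedding layer's weight matrix: one learnable row per class
embedding_table = rng.normal(size=(num_classes, noise_dim))

def condition_noise(noise, labels):
    """Mimic Flatten()(Embedding(...)(label)) followed by multiply([noise, embedding])."""
    label_embedding = embedding_table[labels.reshape(-1)]  # (batch, noise_dim)
    return noise * label_embedding                         # element-wise product

noise = rng.normal(size=(3, noise_dim))
labels = np.array([[7], [7], [2]])  # same (batch, 1) integer layout as the label Input
out = condition_noise(noise, labels)
print(out.shape)  # (3, 100)
```

Two samples with the same label are scaled by the same class vector, which is how the label steers what the downstream MLP produces.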
To train the model we also need a dataset iterator that reads mini-batches of handwritten digit images. The iterator class is as follows:
# -*- coding: utf-8 -*-
# @Time : 2019/10/8 17:29
# @Author : Dai PuWei
# @File : MnistGenerator.py
# @Software: PyCharm
import math
import numpy as np
from keras.datasets import mnist


class MnistGenerator(object):

    def __init__(self, batch_size):
        """
        Initialize the image data iterator.
        :param batch_size: mini-batch size
        """
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
        # Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output
        self.x = np.expand_dims((x_train.astype(np.float32) - 127.5) / 127.5, axis=-1)
        self.y = y_train.reshape(-1, 1)
        self.images_size = len(self.x)
        self._shuffle()
        self.epoch = 1  # index of the current pass over the dataset
        self.batch_size = int(batch_size)
        self.batch_num = math.ceil(self.images_size / self.batch_size)
        self.start = 0
        # Keep a single generator object so its state persists between calls
        self._datagen = self._next_batch()

    def _shuffle(self):
        """Shuffle images and labels with the same permutation."""
        random_index = np.random.permutation(self.images_size)
        self.x = self.x[random_index]
        self.y = self.y[random_index]

    def _next_batch(self):
        """
        Yield (images, labels) mini-batches forever. A batch that runs past the end of
        the data is topped up from a freshly shuffled copy, so every batch is full size,
        and self.epoch is incremented exactly once per pass.
        """
        while True:
            end = self.start + self.batch_size
            if end <= self.images_size:
                batch_images = self.x[self.start:end]
                batch_labels = self.y[self.start:end]
                self.start = end
                if self.start == self.images_size:  # the dataset divides evenly into batches
                    self._shuffle()
                    self.start = 0
                    self.epoch += 1
            else:  # last, partial batch of the pass
                batch_images = self.x[self.start:]
                batch_labels = self.y[self.start:]
                self._shuffle()
                fill = end - self.images_size
                batch_images = np.concatenate((batch_images, self.x[:fill]))
                batch_labels = np.concatenate((batch_labels, self.y[:fill]))
                self.start = fill
                self.epoch += 1
            yield batch_images, batch_labels

    def next_batch(self):
        return next(self._datagen)
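The iterator scales pixels with $(x-127.5)/127.5$, which maps $[0,255]$ onto $[-1,1]$, the output range of the generator's `tanh` layer; `save_image` undoes this with $127.5x+127.5$. A small NumPy round-trip check on a made-up array of pixel values:

```python
import numpy as np

# Made-up uint8 "image" covering the full pixel range, standing in for an MNIST digit
pixels = np.array([[0, 64, 127], [128, 200, 255]], dtype=np.uint8)

# Forward: the scaling used in MnistGenerator, [0, 255] -> [-1, 1]
scaled = (pixels.astype(np.float32) - 127.5) / 127.5
# Inverse: the mapping save_image applies to generator outputs
restored = 127.5 * scaled + 127.5

print(scaled.min(), scaled.max())     # -1.0 1.0
print(np.allclose(restored, pixels))  # True
```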
The training script for CGAN is as follows:
# -*- coding: utf-8 -*-
# @Time : 2019/10/8 15:43
# @Author : Dai PuWei
# @File : train.py
# @Software: PyCharm
from CGAN.CGAN import CGAN
from Config.Config import MnistConfig
from DataGenerator.MnistGenerator import MnistGenerator


def run_main():
    """
    Entry point: train the CGAN on MNIST.
    """
    cfg = MnistConfig()
    cgan = CGAN(cfg)
    batch_size = 512
    #train_datagen = Cifar10Generator(int(batch_size/2))
    train_datagen = MnistGenerator(batch_size)
    # epoch=100000, k=1 discriminator update per generator update
    cgan.train(train_datagen, 100000, 1, batch_size)


if __name__ == '__main__':
    run_main()
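With the settings in train.py (`batch_size = 512`) and the 60,000-image MNIST training split, a quick check of how `MnistGenerator` carves up one pass over the data: `batch_num` is 118, of which 117 batches are full, and the last one holds the remaining 96 images and is padded with 416 images from the reshuffled set.

```python
import math

num_images = 60000   # size of the MNIST training split loaded by MnistGenerator
batch_size = 512     # value used in train.py

steps_per_epoch = math.ceil(num_images / batch_size)  # MnistGenerator.batch_num
full_batches, leftover = divmod(num_images, batch_size)
padding = batch_size - leftover if leftover else 0     # images borrowed from the reshuffled set

print(steps_per_epoch, full_batches, leftover, padding)  # 118 117 96 416
```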
Below are the handwritten digits generated by CGAN during training. After epoch 1:
After epoch 10:
After epoch 100:
After epoch 1000:
The CGAN test code is as follows:
# -*- coding: utf-8 -*-
# @Time : 2019/11/8 13:11
# @Author : DaiPuWei
# @File : test.py
# @Software: PyCharm
import os
from CGAN.CGAN import CGAN
from Config.Config import MnistConfig


def run_main():
    """
    Entry point: load trained weights and save a 10x10 grid of generated digits.
    """
    weight_path = os.path.abspath("./model/20191009134644/Epoch1378dcgan_loss1.5952800512313843discriminator_loss[0.49839333 0.7379193 ]acc[0.49839333 0.7379193 ].h5")
    result_path = os.path.abspath("./test_result")
    if not os.path.exists(result_path):
        os.mkdir(result_path)
    cfg = MnistConfig()
    cgan = CGAN(cfg, weight_path)
    cgan.save_image(0, result_path)


if __name__ == '__main__':
    run_main()
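`save_image` calls `predict` once per row. An equivalent way to build the whole 10x10 grid is to construct all 100 labels and noise vectors up front and call `predict` once. The NumPy sketch below only builds those inputs (the generator call is left as a comment) and assumes a 100-dimensional noise vector, the size hard-coded in the original `save_image`.

```python
import numpy as np

rows, cols, noise_dim = 10, 10, 100  # noise_dim assumed to match generator_noise_input_dim

# Row i of the grid uses label i in every column; each cell gets its own noise vector
labels = np.repeat(np.arange(rows), cols).reshape(-1, 1).astype(np.int32)  # (100, 1)
noise = np.random.normal(0, 1, size=(rows * cols, noise_dim))             # (100, noise_dim)

# A single call would then produce all 100 digits:
#   images = generator_model.predict([noise, labels])
# and image k belongs in grid cell (k // cols, k % cols).
print(labels[:cols].ravel(), labels[-cols:].ravel())
```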