A GAN generally has two parts: a generator and a discriminator.
The discriminator's goal is to tell, as reliably as possible, whether its input is fake data produced by the generator or real data.
The generator's goal is to fool the discriminator into judging the data it generates to be real.
This is an adversarial game that forces both networks to keep improving, until the generator can produce data that passes for real.
The generator's input is a random vector and its output is data of the desired form.
The discriminator's input is data and its output is a number between 0 and 1, interpreted as the probability that the data is real.
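Formally, this game is the minimax objective from the original GAN paper (Goodfellow et al., 2014), where G is the generator, D the discriminator, and z the random input vector:

min_G max_D V(D, G) = E_{x~p_data}[log D(x)] + E_{z~p_z}[log(1 - D(G(z)))]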
The code in this post was written against TensorFlow 2.0.0 and mainly uses Keras.
1. Import the TensorFlow modules
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization, Reshape, Flatten
from tensorflow.keras.losses import BinaryCrossentropy
import numpy as np
import matplotlib.pyplot as plt
import os   # used later to create the output directory for the generated images
The datasets module inside tensorflow provides the MNIST handwritten-digit dataset.
keras.layers contains layers that can be used directly.
keras.losses contains the loss functions.
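To make sure you are on a 2.x build before running anything, a quick check (not in the original post) is:

print(tf.__version__)   # this post assumes 2.0.0; later 2.x releases may need small tweaks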
2. Load and preprocess the data
(train_images, _), (_, _) = datasets.mnist.load_data()
# add a channel dimension and scale pixel values from [0, 255] to [-1, 1]
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5
BATCH_SIZE = 256
BUFFER_SIZE = 60000
train_dataset = tf.data.Dataset.from_tensor_slices(train_images)
train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
The pixel values originally lie in [0, 255]; we normalize them to [-1, 1] (matching the tanh output of the generator defined below), reshape to add a channel dimension, and finally build a tf.data dataset from the result. The variable is named train_dataset so that it does not shadow the datasets module imported above.
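As a quick sanity check (optional, not part of the original code), you can confirm the shapes and value range after preprocessing:

print(train_images.shape)                      # (60000, 28, 28, 1)
print(train_images.min(), train_images.max())  # -1.0 1.0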
3. Define the generator model
def generator_model():
    model = keras.Sequential()
    model.add(Dense(256, input_shape=(100,), use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())
    model.add(Dense(512, use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())
    # output layer: tanh keeps values in [-1, 1]; no BatchNormalization or
    # LeakyReLU follows it, since either would push values outside that range
    model.add(Dense(784, use_bias=False, activation='tanh'))
    model.add(Reshape((28, 28, 1)))
    return model
The generator's input is a random 100-dimensional vector.
The generator consists of three fully connected layers, the last of which is the output layer: it has 784 neurons because the output must be a 28x28 image, and after the activation the result is reshaped into a 28x28x1 image. The tanh activation keeps the generated values in [-1, 1], the same range as the normalized real images, which is why no BatchNormalization or LeakyReLU is applied after it.
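You can confirm the layer shapes with Keras' summary method; this is an optional check, not in the original post:

g = generator_model()
g.summary()   # the last line should show an output shape of (None, 28, 28, 1)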
4. Define the discriminator model
def discriminator_model():
    model = keras.Sequential()
    model.add(Flatten())                    # 28x28x1 image -> 784-dimensional vector
    model.add(Dense(512, use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())
    model.add(Dense(256, use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())
    model.add(Dense(1))                     # single logit: higher means "more real"
    return model
The discriminator consists of a Flatten layer and three fully connected layers, the last of which has a single neuron. Note that it outputs a raw logit rather than a probability: the loss defined below is constructed with from_logits=True and applies the sigmoid internally, so no sigmoid activation is needed here.
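To see that the raw output is a logit, you can feed the (untrained) discriminator a random image and apply the sigmoid yourself; a minimal sketch, not in the original post:

d = discriminator_model()
dummy = tf.random.normal([1, 28, 28, 1])   # one fake "image" of pure noise
logit = d(dummy)                           # unbounded real number
print(tf.sigmoid(logit))                   # squashed into (0, 1), read as P(real)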
5. Define the loss functions and optimizers
cross_entropy = BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_out, fake_out):
    # the discriminator should score real images as 1 and fake images as 0
    real_loss = cross_entropy(tf.ones_like(real_out), real_out)
    fake_loss = cross_entropy(tf.zeros_like(fake_out), fake_out)
    return real_loss + fake_loss

def generator_loss(fake_out):
    # the generator wants the discriminator to score its fakes as 1
    return cross_entropy(tf.ones_like(fake_out), fake_out)

generator_opt = keras.optimizers.Adam(1e-4)
discriminator_opt = keras.optimizers.Adam(1e-4)
EPOCHS = 50
noise_dim = 100              # dimension of the generator's random input vector
num_exp_to_generate = 16     # number of example images to plot each epoch
seed = tf.random.normal([num_exp_to_generate, noise_dim])   # fixed noise, reused every epoch
generator = generator_model()
discriminator = discriminator_model()
Here real_out is the discriminator's output when it is fed real images, and fake_out is its output when it is fed images produced by the generator. The seed noise is fixed up front so that the same 16 samples are plotted after every epoch, making progress easy to compare.
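A small worked example (not in the original post) shows the scale of these losses: for a logit of 0 the sigmoid gives 0.5, so each cross-entropy term is -log(0.5) ≈ 0.693:

logits = tf.zeros([4, 1])                           # an undecided discriminator
print(generator_loss(logits).numpy())               # ≈ 0.6931
print(discriminator_loss(logits, logits).numpy())   # ≈ 1.3863, the sum of two 0.6931 terms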
6. Define the training step
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    # record both forward passes so each network gets its own gradients
    with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
        real_out = discriminator(images, training=True)
        gen_image = generator(noise, training=True)
        fake_out = discriminator(gen_image, training=True)
        gen_loss = generator_loss(fake_out)
        dis_loss = discriminator_loss(real_out, fake_out)
    gradient_gen = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradient_dis = dis_tape.gradient(dis_loss, discriminator.trainable_variables)
    generator_opt.apply_gradients(zip(gradient_gen, generator.trainable_variables))
    discriminator_opt.apply_gradients(zip(gradient_dis, discriminator.trainable_variables))
    return gen_loss, dis_loss
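As an optional speed-up that the original code does not use, the training step can be compiled into a TensorFlow graph:

train_step = tf.function(train_step)   # traces the step once, then runs the compiled graph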
7. Define the plotting function
def generate_plot_image(gen_model, test_noise, epoch):
    pre_images = gen_model(test_noise, training=False)
    os.makedirs('./images', exist_ok=True)   # savefig fails if the directory is missing
    fig = plt.figure(figsize=(4, 4))
    for i in range(pre_images.shape[0]):
        plt.subplot(4, 4, i + 1)
        # map values back from [-1, 1] to [0, 1] for display
        plt.imshow((pre_images[i, :, :, 0] + 1) / 2, cmap='gray')
        plt.axis('off')
    plt.savefig('./images/image_at_epoch_{:04d}.png'.format(epoch))
    plt.close(fig)
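It can be instructive to call the function once before training; with an untrained generator the saved grid should look like pure noise (the training loop will later overwrite this file after its first epoch):

generate_plot_image(generator, seed, 0)   # writes ./images/image_at_epoch_0000.png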
8. Start training
def train(dataset, epochs):
    for epoch in range(epochs):
        for image_batch in dataset:
            gen_loss, dis_loss = train_step(image_batch)
        # report the losses of the last batch and save a sample grid
        print('epoch {} finished'.format(epoch + 1))
        print('gen_loss:', gen_loss.numpy(), 'dis_loss:', dis_loss.numpy())
        generate_plot_image(generator, seed, epoch)
    print('finished')

train(train_dataset, EPOCHS)
After only a handful of epochs you can already vaguely make out handwritten digits.
The result after training for 50 epochs is shown below.
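Once training is done, the generator can be used on its own: draw fresh noise vectors and decode them into new digits. A minimal sketch, assuming training has completed:

new_noise = tf.random.normal([16, noise_dim])
new_digits = generator(new_noise, training=False)   # shape (16, 28, 28, 1), values in [-1, 1]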