(These notes are adapted from Deep Learning with Python.)
The key idea of image generation is to find a low-dimensional latent space of representations in which any point can be mapped to a realistic-looking image. The module that performs this mapping is called a generator (for GANs) or a decoder (for VAEs).
Respective strengths and weaknesses of VAEs and GANs:
VAEs are good at learning latent spaces with a well-defined structure (continuous, low-dimensional);
GANs generate highly realistic images, but their latent space may lack such structure and continuity.
Goal: encode the input into a low-dimensional latent space, then decode it back to an output with the same dimensions as the original image.
Pseudocode sketch:
z_mean, z_log_variance = encoder(input_img)
z = z_mean + exp(0.5 * z_log_variance) * epsilon  # epsilon ~ N(0, 1); the reparameterization trick keeps z differentiable w.r.t. z_mean and z_log_variance
reconstructed_img = decoder(z)
model = Model(input_img, reconstructed_img)
Full code:
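The notes omit the encoder that defines input_img, z_mean, z_log_var, shape_before_flattening, latent_dim, and batch_size used below. A minimal sketch following the book's MNIST convnet encoder (treat the exact layer sizes as assumptions):
import keras
import numpy as np
from keras import layers
from keras import backend as K
from keras.models import Model

img_shape = (28, 28, 1)
batch_size = 16
latent_dim = 2  # 2-D latent space, so it can be visualized on a plane

input_img = keras.Input(shape=img_shape)
x = layers.Conv2D(32, 3, padding='same', activation='relu')(input_img)
x = layers.Conv2D(64, 3, padding='same', activation='relu', strides=(2, 2))(x)
x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)
x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)
shape_before_flattening = K.int_shape(x)  # needed later to un-flatten in the decoder
x = layers.Flatten()(x)
x = layers.Dense(32, activation='relu')(x)
z_mean = layers.Dense(latent_dim)(x)     # mean of the latent Gaussian
z_log_var = layers.Dense(latent_dim)(x)  # log-variance of the latent Gaussian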
# Latent-space sampling, wrapped in a Lambda layer
def sampling(args):
    z_mean, z_log_var = args
    # Reparameterization trick: sample epsilon ~ N(0, 1) externally so gradients flow through z_mean and z_log_var
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim), mean=0., stddev=1.)
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

z = layers.Lambda(sampling)([z_mean, z_log_var])
# VAE decoder network: maps latent-space points back to images
decoder_input = layers.Input(K.int_shape(z)[1:])
x = layers.Dense(np.prod(shape_before_flattening[1:]), activation='relu')(decoder_input)
x = layers.Reshape(shape_before_flattening[1:])(x)
# Upsample back to the original 28x28 resolution
x = layers.Conv2DTranspose(32, 3, padding='same', activation='relu', strides=(2, 2))(x)
# Collapse to a single channel with sigmoid pixel values, matching the input shape
x = layers.Conv2D(1, 3, padding='same', activation='sigmoid')(x)
decoder = Model(decoder_input, x)
z_decoded = decoder(z)
# Custom layer that computes the VAE loss
class CustomVariationalLayer(keras.layers.Layer):

    def vae_loss(self, x, z_decoded):
        x = K.flatten(x)
        z_decoded = K.flatten(z_decoded)
        # Reconstruction loss
        xent_loss = keras.metrics.binary_crossentropy(x, z_decoded)
        # KL-divergence regularization term, scaled down by 5e-4
        kl_loss = -5e-4 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
        return K.mean(xent_loss + kl_loss)

    def call(self, inputs):
        x = inputs[0]
        z_decoded = inputs[1]
        loss = self.vae_loss(x, z_decoded)
        self.add_loss(loss, inputs=inputs)
        # The output is not used, but a layer must return something
        return x

y = CustomVariationalLayer()([input_img, z_decoded])
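For reference, the kl_loss line above implements the closed-form KL divergence between the approximate posterior $N(\mu, \sigma^2)$ (a diagonal Gaussian, with z_log_var $= \log\sigma^2$) and the standard normal prior; the $5\times 10^{-4}$ factor is the book's down-weighting of this term:
$$\mathrm{KL}\big(N(\mu,\sigma^2)\,\|\,N(0,I)\big) = -\frac{1}{2}\sum_j\big(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\big)$$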
# Train the VAE on MNIST
from keras.datasets import mnist

vae = Model(input_img, y)
# The loss was already added inside CustomVariationalLayer via add_loss, so compile with loss=None
vae.compile(optimizer='rmsprop', loss=None)
vae.summary()

(x_train, _), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_train = x_train.reshape(x_train.shape + (1,))
x_test = x_test.astype('float32') / 255.
x_test = x_test.reshape(x_test.shape + (1,))

# Since the loss does not use targets, pass y=None
vae.fit(x=x_train, y=None, shuffle=True, epochs=10, batch_size=batch_size, validation_data=(x_test, None))
# Use the trained network: sample a grid of points from the 2-D latent space and decode them to images
import matplotlib.pyplot as plt
from scipy.stats import norm

n = 15  # display a 15x15 grid of digits
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
# Map linearly spaced quantiles through the inverse Gaussian CDF, since the latent prior is Gaussian
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

for i, yi in enumerate(grid_x):
    for j, xi in enumerate(grid_y):
        z_sample = np.array([xi, yi])
        # Repeat z to form a full batch
        z_sample = np.tile(z_sample, batch_size).reshape(batch_size, 2)
        x_decoded = decoder.predict(z_sample, batch_size=batch_size)
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()
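A complementary visualization (a sketch, not in the original notes) is to project the test digits into the latent plane via the z_mean output and color them by class, which shows how the digit classes cluster in the 2-D latent space:
# Hypothetical companion plot: scatter the test set in the latent plane
encoder = Model(input_img, z_mean)
z_test = encoder.predict(x_test, batch_size=batch_size)
plt.figure(figsize=(6, 6))
plt.scatter(z_test[:, 0], z_test[:, 1], c=y_test, cmap='viridis', s=2)
plt.colorbar()
plt.show()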
Generator network: takes a random vector (a point in the latent space) as input and decodes it into a synthetic image.
Discriminator network: takes an image (real or synthetic) as input and predicts whether it came from the training set or from the generator.
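Formally (not spelled out in the notes), the two networks play the standard minimax game; note that the code below inverts the usual label convention (1 = fake, 0 = real):
$$\min_G \max_D \; \mathbb{E}_{x\sim p_{\mathrm{data}}}[\log D(x)] + \mathbb{E}_{z\sim p(z)}[\log(1 - D(G(z)))]$$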
Full code:
# Generator
import keras
from keras import layers
import numpy as np

latent_dim = 32
height = 32
width = 32
channels = 3

generator_input = keras.Input(shape=(latent_dim,))
# Transform the input into a 16x16, 128-channel feature map
x = layers.Dense(128 * 16 * 16)(generator_input)
x = layers.LeakyReLU()(x)
x = layers.Reshape((16, 16, 128))(x)
x = layers.Conv2D(256, 5, padding='same')(x)
x = layers.LeakyReLU()(x)
# Upsample to 32x32
x = layers.Conv2DTranspose(256, 4, strides=2, padding='same')(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(256, 5, padding='same')(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(256, 5, padding='same')(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(channels, 7, activation='tanh', padding='same')(x)
# Instantiate the generator model: it maps an input of shape (latent_dim,) to an image of shape (32, 32, 3)
generator = keras.models.Model(generator_input, x)
generator.summary()
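As a quick shape check (a usage sketch, not from the notes), decode a small batch of random latent vectors:
# Random noise in, images out; expected output shape: (4, 32, 32, 3)
random_latent_vectors = np.random.normal(size=(4, latent_dim))
fake_images = generator.predict(random_latent_vectors)
print(fake_images.shape)  # pixel values lie in [-1, 1] because of the tanh output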
# Discriminator
discriminator_input = layers.Input(shape=(height, width, channels))
x = layers.Conv2D(128, 3)(discriminator_input)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, 4, strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, 4, strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, 4, strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Flatten()(x)
x = layers.Dropout(0.4)(x)
x = layers.Dense(1, activation='sigmoid')(x)  # classification layer
# Instantiate the discriminator model: it turns a (32, 32, 3) input into a binary classification decision (real/fake)
discriminator = keras.models.Model(discriminator_input, x)
discriminator.summary()
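Symmetrically (again a sketch), the discriminator maps a batch of images to one score per image, reusing fake_images from the generator check above:
# Each score lies in (0, 1); the network is still untrained at this point
scores = discriminator.predict(fake_images)
print(scores.shape)  # (4, 1)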
# Use gradient clipping in the optimizer to bound gradient values, and learning-rate decay to stabilize training
discriminator_optimizer = keras.optimizers.RMSprop(lr=0.0008, clipvalue=1.0, decay=1e-8)
discriminator.compile(optimizer=discriminator_optimizer, loss='binary_crossentropy')
# Adversarial network: maps a latent-space point to a classification decision; the discriminator must be frozen here.
# Note: trainable = False only takes effect when a model is compiled, so the discriminator still learns
# via its own train_on_batch, but stays frozen inside gan
discriminator.trainable = False
gan_input = keras.Input(shape=(latent_dim,))
gan_output = discriminator(generator(gan_input))
gan = keras.models.Model(gan_input, gan_output)
gan_optimizer = keras.optimizers.RMSprop(lr=0.0004, clipvalue=1.0, decay=1e-8)
gan.compile(optimizer=gan_optimizer, loss='binary_crossentropy')
# Train the DCGAN
Note: each training iteration performs the following steps, mirrored by the comments in the loop below: (1) draw random points in the latent space; (2) decode them into fake images with generator; (3) mix the fake images with real ones; (4) train discriminator on this mix, with labels distinguishing real from fake; (5) draw new random latent points; (6) train generator via the gan model, with labels that all say "real image" (the discriminator is frozen inside gan).
import os
from keras.preprocessing import image

(x_train, y_train), (_, _) = keras.datasets.cifar10.load_data()
x_train = x_train[y_train.flatten() == 6]  # select frog images (class label 6)
# Normalize the data
x_train = x_train.reshape((x_train.shape[0],) + (height, width, channels)).astype('float32') / 255.

iterations = 10000
batch_size = 20
save_dir = 'your_dir'

start = 0
for step in range(iterations):
    # Sample random points in the latent space
    random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))
    # Decode them into fake images
    generated_images = generator.predict(random_latent_vectors)
    # Mix the fake images with real ones
    stop = start + batch_size
    real_images = x_train[start:stop]
    combined_images = np.concatenate([generated_images, real_images])
    # Labels: 1 = fake, 0 = real
    labels = np.concatenate([np.ones((batch_size, 1)), np.zeros((batch_size, 1))])
    # Add random noise to the labels
    labels += 0.05 * np.random.random(labels.shape)
    # Train the discriminator
    d_loss = discriminator.train_on_batch(combined_images, labels)
    # Sample new random points in the latent space
    random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))
    # Misleading labels that claim all images are real
    misleading_targets = np.zeros((batch_size, 1))
    # Train the generator through the gan model (the discriminator weights are frozen)
    a_loss = gan.train_on_batch(random_latent_vectors, misleading_targets)

    start += batch_size
    if start > len(x_train) - batch_size:
        start = 0

    if step % 100 == 0:
        gan.save_weights('gan.h5')  # save the model weights
        print('discriminator loss:', d_loss)
        print('adversarial loss:', a_loss)
        img = image.array_to_img(generated_images[0] * 255., scale=False)
        img.save(os.path.join(save_dir, 'generated_frog' + str(step) + '.png'))
        img = image.array_to_img(real_images[0] * 255., scale=False)
        img.save(os.path.join(save_dir, 'real_frog' + str(step) + '.png'))
【Training tips】