For the fundamentals of variational autoencoders (VAEs), see the earlier post 变分自编码器(VAE)原理与实现(tensorflow2.x). As an application of VAEs, we will use one to generate face images with controllable attributes. Among the available face datasets, we use Celeb_A, loaded via tensorflow_datasets in the code below.
The network's input shape is (112, 112, 3); images are resized to this size during preprocessing. Because the dataset is complex, we increase the number of filters to give the network more capacity. The convolutional layers in the encoder are therefore as follows (a standalone sketch follows the list):
a) Conv2D(filters = 32, kernel_size=(3,3), strides = 2)
b) Conv2D(filters = 32, kernel_size=(3,3), strides = 2)
c) Conv2D(filters = 64, kernel_size=(3,3), strides = 2)
d) Conv2D(filters = 64, kernel_size=(3,3), strides = 2)
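Each stride-2 convolution halves the spatial resolution, taking 112 down to 56, 28, 14, and finally 7, which is why the decoder later reshapes its dense output to 7×7×64. A minimal standalone sketch of this downsampling path (the full listing below wraps each Conv2D in a DownConvBlock that adds BatchNormalization and LeakyReLU):

import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Input, Conv2D

backbone = Sequential([
    Input((112, 112, 3)),
    Conv2D(32, (3, 3), strides=2, padding='same'),  # -> (56, 56, 32)
    Conv2D(32, (3, 3), strides=2, padding='same'),  # -> (28, 28, 32)
    Conv2D(64, (3, 3), strides=2, padding='same'),  # -> (14, 14, 64)
    Conv2D(64, (3, 3), strides=2, padding='same'),  # -> (7, 7, 64)
])
print(backbone.output_shape)  # (None, 7, 7, 64)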
Let's first look at the VAE's reconstructed images:
Although the reconstructed images are not perfect, they at least look reasonable. The VAE has managed to learn features from the input images and uses them to draw new faces. Notice that the VAE reconstructs female faces better; this is because the proportion of women in the Celeb_A dataset is higher. It is also why the male faces tend to look younger and more feminine.
Looking at the image backgrounds: because the backgrounds are so diverse, the encoder cannot encode every detail into the low-dimensional latent space. We can see that the VAE encodes only the background color, and the decoder creates a blurry background based on that color.
To generate new images, we sample random numbers from a standard Gaussian distribution and pass them to the decoder:
z_samples = np.random.normal(loc=0., scale=1, size=(image_num, z_dim))
images = vae.decoder(z_samples.astype(np.float32))
However, some of the generated faces look rather frightening!
We can use a sampling trick to improve image fidelity.
As we can see, the trained VAE reconstructs faces well, but there are problems in the images generated from randomly sampled latent variables. To debug this, we feed the dataset images into the VAE encoder to obtain the means and variances of the latent space, and then plot the mean of each latent variable:
In theory they should have mean 0 and variance 1, but randomly drawn samples do not always match the distribution the decoder expects. This is where the sampling trick comes in: we collect the average standard deviation of the latent variables (a scalar) and use it as the scale when sampling from a normal distribution (200 dimensions); we then add the average mean (a 200-dimensional vector) to the samples.
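In code, the sampling trick boils down to the following (avg_z_mean and avg_z_std are the statistics collected from the encoder in the full listing below):

z_samples = np.random.normal(loc=0., scale=np.mean(avg_z_std), size=(image_num, z_dim))
z_samples += avg_z_mean  # shift the samples to the mean the decoder expects
images = vae.decoder(z_samples.astype(np.float32))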
Yes, the generated images now look much better!
Next, we'll look at how to edit facial attributes rather than generate random faces.
Essentially, the latent space spans every possible value of the latent variables; in our VAE it is a 200-dimensional vector (that is, 200 variables). We would like each variable to carry a unique meaning, say z[0] for the eyes, z[1] for eye color, and so on, but things are never that simple. Assuming instead that the information is encoded across the whole latent vector, we can explore the latent space with vector arithmetic.
Let's use a two-dimensional example to explain the principle of attribute control. Suppose we are at point (0, 0) on a map and the destination is at (x, y). The direction toward the destination is (x-0, y-0) divided by the L2 norm of (x, y); call it (x_dot, y_dot). Every time we move by (x_dot, y_dot), we take one step toward the destination; every time we move by (-2*x_dot, -2*y_dot), we move away from it at twice the stride.
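A minimal numpy sketch of this idea (the destination (3, 4) is just an assumed example point):

import numpy as np
destination = np.array([3.0, 4.0])  # example (x, y)
direction = destination / np.linalg.norm(destination)  # unit vector (x_dot, y_dot)
position = np.array([0.0, 0.0])
position += direction        # one step toward the destination
position += -2 * direction   # two strides back, away from it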
Similarly, if we know the direction vector for the smiling attribute, we can add it to the latent variables to make a face smile more:
new_z_samples = z_samples + smiling_magnitude*smiling_vector
smiling_magnitude is a scalar value we choose; the next step is to work out how to obtain smiling_vector.
The Celeb_A dataset ships with facial-attribute annotations for every image. The labels are binary, indicating whether a given attribute is present in the image. We will use these labels together with the encoded latent variables to find our direction vectors:
def preprocess_attrib(sample, attribute):
    image = sample['image']
    image = tf.image.resize(image, [112, 112])
    image = tf.cast(image, tf.float32) / 255.
    return image, sample['attributes'][attribute]
ds = ds.map(lambda x: preprocess_attrib(x, attribute))
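The direction vector itself is just the difference between the average latent code of images that have the attribute and the average of those that do not; this is what extract_attrib_vector in the full listing below computes:

z, _, _ = vae.encoder(images)  # encode a labelled batch
z = z.numpy()
pos_mean = np.mean(z[labels == True], axis=0)   # faces with the attribute
neg_mean = np.mean(z[labels == False], axis=0)  # faces without it
attrib_vector = pos_mean - neg_mean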
After extracting an attribute vector, we add multiples of it to a latent code and decode the result, as sketched below.
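For example, sweeping a step coefficient from -3 to +3 and decoding each shifted latent code (this mirrors explore_latent_variable in the full listing below):

for step in range(-3, 4):
    new_z_samples = z_samples + step * attribs_vectors[attrib]
    reconstructed_image = vae.decoder(new_z_samples)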
The figure below shows the images generated by interpolating the latent vector:
Next, we can try changing several facial attributes at once. In the figure below, the image on the left is randomly generated and serves as the baseline; on the right is the new image after some latent-space arithmetic:
The interactive widget (built with ipywidgets at the end of the following code) can be used in a Jupyter notebook.
# vae_faces.ipynb
import tensorflow as tf
from tensorflow_probability import distributions as tfd
from tensorflow.keras import layers, Model
from tensorflow.keras.layers import Layer, Input, Conv2D, Dense, Flatten, Reshape, Lambda, Dropout
from tensorflow.keras.layers import Conv2DTranspose, MaxPooling2D, UpSampling2D, LeakyReLU, BatchNormalization
from tensorflow.keras.activations import relu
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow_datasets as tfds
import cv2
import numpy as np
import matplotlib.pyplot as plt
import datetime, os
import warnings
warnings.filterwarnings('ignore')
print("Tensorflow", tf.__version__)
strategy = tf.distribute.MirroredStrategy()
num_devices = strategy.num_replicas_in_sync
print('Number of devices: {}'.format(num_devices))
(ds_train, ds_test_), ds_info = tfds.load(
    'celeb_a',
    split=['train', 'test'],
    shuffle_files=True,
    with_info=True)
fig = tfds.show_examples(ds_train, ds_info)
batch_size = 32 * num_devices
def preprocess(sample):
    image = sample['image']
    image = tf.image.resize(image, [112, 112])
    image = tf.cast(image, tf.float32) / 255.
    return image, image
ds_train = ds_train.map(preprocess)
ds_train = ds_train.shuffle(128)
ds_train = ds_train.batch(batch_size, drop_remainder=True).prefetch(batch_size)
ds_test = ds_test_.map(preprocess).batch(batch_size, drop_remainder=True).prefetch(batch_size)
train_num = ds_info.splits['train'].num_examples
test_num = ds_info.splits['test'].num_examples
class GaussianSampling(Layer):
    # reparameterization trick: z = mean + exp(0.5 * logvar) * epsilon
    def call(self, inputs):
        means, logvar = inputs
        epsilon = tf.random.normal(shape=tf.shape(means), mean=0., stddev=1.)
        samples = means + tf.exp(0.5 * logvar) * epsilon
        return samples
class DownConvBlock(Layer):
    count = 0
    def __init__(self, filters, kernel_size=(3,3), strides=1, padding='same'):
        super(DownConvBlock, self).__init__(name=f"DownConvBlock_{DownConvBlock.count}")
        DownConvBlock.count += 1
        self.forward = Sequential([
            Conv2D(filters, kernel_size, strides, padding),
            BatchNormalization(),
            LeakyReLU(0.2)
        ])
    def call(self, inputs):
        return self.forward(inputs)
class UpConvBlock(Layer):
    count = 0
    def __init__(self, filters, kernel_size=(3,3), padding='same'):
        super(UpConvBlock, self).__init__(name=f"UpConvBlock_{UpConvBlock.count}")
        UpConvBlock.count += 1
        self.forward = Sequential([
            Conv2D(filters, kernel_size, 1, padding),
            LeakyReLU(0.2),
            UpSampling2D((2,2))
        ])
    def call(self, inputs):
        return self.forward(inputs)
class Encoder(Layer):
    def __init__(self, z_dim, name='encoder'):
        super(Encoder, self).__init__(name=name)
        self.features_extract = Sequential([
            DownConvBlock(filters=32, kernel_size=(3,3), strides=2),
            DownConvBlock(filters=32, kernel_size=(3,3), strides=2),
            DownConvBlock(filters=64, kernel_size=(3,3), strides=2),
            DownConvBlock(filters=64, kernel_size=(3,3), strides=2),
            Flatten()
        ])
        self.dense_mean = Dense(z_dim, name='mean')
        self.dense_logvar = Dense(z_dim, name='logvar')
        self.sampler = GaussianSampling()
    def call(self, inputs):
        x = self.features_extract(inputs)
        mean = self.dense_mean(x)
        logvar = self.dense_logvar(x)
        z = self.sampler([mean, logvar])
        return z, mean, logvar
class Decoder(Layer):
    def __init__(self, z_dim, name='decoder'):
        super(Decoder, self).__init__(name=name)
        self.forward = Sequential([
            Dense(7*7*64, activation='relu'),
            Reshape((7,7,64)),
            UpConvBlock(filters=64, kernel_size=(3,3)),
            UpConvBlock(filters=64, kernel_size=(3,3)),
            UpConvBlock(filters=32, kernel_size=(3,3)),
            UpConvBlock(filters=32, kernel_size=(3,3)),
            Conv2D(filters=3, kernel_size=(3,3), strides=1, padding='same', activation='sigmoid')
        ])
    def call(self, inputs):
        return self.forward(inputs)
class VAE(Model):
    def __init__(self, z_dim, name='VAE'):
        super(VAE, self).__init__(name=name)
        self.encoder = Encoder(z_dim)
        self.decoder = Decoder(z_dim)
        self.mean = None
        self.logvar = None
    def call(self, inputs):
        z, self.mean, self.logvar = self.encoder(inputs)
        out = self.decoder(z)
        return out
if num_devices > 1:
    with strategy.scope():
        vae = VAE(z_dim=200)
else:
    vae = VAE(z_dim=200)
def vae_kl_loss(y_true, y_pred):
    kl_loss = -0.5 * tf.reduce_mean(1 + vae.logvar - tf.square(vae.mean) - tf.exp(vae.logvar))
    return kl_loss
def vae_rc_loss(y_true, y_pred):
    rc_loss = tf.keras.losses.MSE(y_true, y_pred)
    return rc_loss
def vae_loss(y_true, y_pred):
    kl_loss = vae_kl_loss(y_true, y_pred)
    rc_loss = vae_rc_loss(y_true, y_pred)
    kl_weight_const = 0.01
    return kl_weight_const * kl_loss + rc_loss
model_path = "vae_faces_cele_a.h5"
checkpoint = ModelCheckpoint(
    model_path,
    monitor='vae_rc_loss',
    verbose=1,
    save_best_only=True,
    mode='auto',
    save_weights_only=True
)
early = EarlyStopping(
    monitor='vae_rc_loss',
    mode='auto',
    patience=3
)
callbacks_list = [checkpoint, early]
initial_learning_rate = 1e-3
steps_per_epoch = int(np.round(train_num/batch_size))
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=steps_per_epoch,
    decay_rate=0.96,
    staircase=True
)
vae.compile(
    loss=[vae_loss],
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=lr_schedule),  # use the decay schedule defined above
    metrics=[vae_kl_loss, vae_rc_loss]
)
history = vae.fit(ds_train, validation_data=ds_test, epochs=50, callbacks=callbacks_list)
images, labels = next(iter(ds_train))
vae.load_weights(model_path)
outputs = vae.predict(images)
# Display
grid_col = 8
grid_row = 2
f, axarr = plt.subplots(grid_row, grid_col, figsize=(grid_col*2, grid_row*2))
i = 0
for row in range(0, grid_row, 2):
    for col in range(grid_col):
        axarr[row, col].imshow(images[i])
        axarr[row, col].axis('off')
        axarr[row+1, col].imshow(outputs[i])
        axarr[row+1, col].axis('off')
        i += 1
f.tight_layout(pad=0.1, h_pad=0.2, w_pad=0.1)
plt.show()
avg_z_mean = []
avg_z_std = []
train_iter = iter(ds_train)  # create the iterator once instead of on every step
for i in range(steps_per_epoch):
    images, labels = next(train_iter)
    z, z_mean, z_logvar = vae.encoder(images)
    avg_z_mean.append(np.mean(z_mean, axis=0))
    avg_z_std.append(np.mean(np.exp(0.5*z_logvar), axis=0))
avg_z_mean = np.mean(avg_z_mean, axis=0)
avg_z_std = np.mean(avg_z_std, axis=0)
plt.plot(avg_z_mean)
plt.ylabel("Average z mean")
plt.xlabel("z dimension")
grid_col = 10
grid_row = 10
f, axarr = plt.subplots(grid_row, grid_col, figsize=(grid_col, 1.5*grid_row))
i = 0
for row in range(grid_row):
    for col in range(grid_col):
        axarr[row, col].hist(z[:, i], bins=20)
        # axarr[row, col].axis('off')
        i += 1
#f.tight_layout(0.1, h_pad=0.2, w_pad=0.1)
plt.show()
z_dim = 200
z_samples = np.random.normal(loc=0, scale=1, size=(25, z_dim))
images = vae.decoder(z_samples.astype(np.float32))
grid_col = 7
grid_row = 2
f, axarr = plt.subplots(grid_row, grid_col, figsize=(2*grid_col,2*grid_row))
i = 0
for row in range(grid_row):
    for col in range(grid_col):
        axarr[row, col].imshow(images[i])
        axarr[row, col].axis('off')
        i += 1
f.tight_layout(pad=0.1, h_pad=0.2, w_pad=0.1)
plt.show()
# Sampling trick: scale by the average latent std, then shift by the average latent mean
z_samples = np.random.normal(loc=0., scale=np.mean(avg_z_std), size=(25, z_dim))
z_samples += avg_z_mean
images = vae.decoder(z_samples.astype(np.float32))
grid_col = 7
grid_row = 2
f, axarr = plt.subplots(grid_row, grid_col, figsize=(2*grid_col, 2*grid_row))
i = 0
for row in range(grid_row):
    for col in range(grid_col):
        axarr[row, col].imshow(images[i])
        axarr[row, col].axis('off')
        i += 1
f.tight_layout(pad=0.1, h_pad=0.2, w_pad=0.1)
plt.show()
(ds_train, ds_test), ds_info = tfds.load(
    'celeb_a',
    split=['train', 'test'],
    shuffle_files=True,
    with_info=True)
test_num = ds_info.splits['test'].num_examples
def preprocess_attrib(sample, attribute):
    image = sample['image']
    image = tf.image.resize(image, [112, 112])
    image = tf.cast(image, tf.float32) / 255.
    return image, sample['attributes'][attribute]
def extract_attrib_vector(attribute, ds):
    batch_size = 32 * num_devices
    ds = ds.map(lambda x: preprocess_attrib(x, attribute))
    ds = ds.batch(batch_size)
    steps_per_epoch = int(np.round(test_num / batch_size))
    pos_z = []
    pos_z_num = []
    neg_z = []
    neg_z_num = []
    ds_iter = iter(ds)  # create the iterator once so every step sees a new batch
    for i in range(steps_per_epoch):
        images, labels = next(ds_iter)
        z, z_mean, z_logvar = vae.encoder(images)
        z = z.numpy()
        step_pos_z = z[labels == True]    # latent codes of images with the attribute
        pos_z.append(np.mean(step_pos_z, axis=0))
        pos_z_num.append(step_pos_z.shape[0])
        step_neg_z = z[labels == False]   # latent codes of images without it
        neg_z.append(np.mean(step_neg_z, axis=0))
        neg_z_num.append(step_neg_z.shape[0])
    # weighted averages over batches; their difference is the attribute direction
    avg_pos_z = np.average(pos_z, axis=0, weights=pos_z_num)
    avg_neg_z = np.average(neg_z, axis=0, weights=neg_z_num)
    attrib_vector = avg_pos_z - avg_neg_z
    return attrib_vector
attributes = list(ds_info.features['attributes'].keys())
attribs_vectors = {}
for attrib in attributes:
    print(attrib)
    attribs_vectors[attrib] = extract_attrib_vector(attrib, ds_test)
def explore_latent_variable(image, attrib):
    grid_col = 8
    grid_row = 1
    z_samples, _, _ = vae.encoder(tf.expand_dims(image, 0))
    f, axarr = plt.subplots(grid_row, grid_col, figsize=(2*grid_col, 2*grid_row))
    axarr[0].imshow(image)
    axarr[0].axis('off')
    step = -3  # sweep the attribute coefficient from -3 to +3
    for col in range(1, grid_col):
        new_z_samples = z_samples + step * attribs_vectors[attrib]
        reconstructed_image = vae.decoder(new_z_samples)
        step += 1
        axarr[col].imshow(reconstructed_image[0])
        axarr[col].axis('off')
    f.tight_layout(pad=0.1, h_pad=0.2, w_pad=0.1)
    plt.show()
ds_test1 = ds_test.map(preprocess).batch(100)
images, labels = next(iter(ds_test1))
# Generate face images by controlling attribute vectors
explore_latent_variable(images[34], 'Male')
explore_latent_variable(images[20], 'Eyeglasses')
explore_latent_variable(images[38], "Chubby")
fname = ""
if fname:
# using existing image from file
image = cv2.imread(fname)
image = image[:,:,::-1]
# crop
min_dim = min(h, w)
h_gap = (h-min_dim) // 2
w_gap = (w-min_dim) // 2
image = image[h_gap:h-h_gap, w_gap,w-w_gap, :]
image = cv2.resize(image, (112,112))
plt.imshow(image)
# encode
input_tensor = np.expand_dims(image, 0)
input_tensor = input_tensor.astype(np.float32) / 255.
z_samples = vae.encoder(input_tensor)
else:
# start with random image
z_samples = np.random.normal(loc=0., scale=np.mean(avg_z_std), size=(1, 200))
import ipywidgets as widgets
from ipywidgets import interact, interact_manual
@interact
def explore_latent_variable(Male=(-5, 5, 0.1),
                            Eyeglasses=(-5, 5, 0.1),
                            Young=(-5, 5, 0.1),
                            Smiling=(-5, 5, 0.1),
                            Blond_Hair=(-5, 5, 0.1),
                            Pale_Skin=(-5, 5, 0.1),
                            Mustache=(-5, 5, 0.1)):
    new_z_samples = z_samples + \
        Male * attribs_vectors['Male'] + \
        Eyeglasses * attribs_vectors['Eyeglasses'] + \
        Young * attribs_vectors['Young'] + \
        Smiling * attribs_vectors['Smiling'] + \
        Blond_Hair * attribs_vectors['Blond_Hair'] + \
        Pale_Skin * attribs_vectors['Pale_Skin'] + \
        Mustache * attribs_vectors['Mustache']
    images = vae.decoder(new_z_samples)
    plt.figure(figsize=(4, 4))
    plt.axis('off')
    plt.imshow(images[0])