import tensorflow as tf
import numpy as np
from PIL import Image
from tensorflow import keras
from tensorflow.keras import datasets, layers
def preprocess(x):
    """Cast ``x`` to float32 and divide it by 255.

    Used as the ``map`` function of the training pipeline; presumably the
    inputs are 8-bit pixel values, in which case the result lies in [0, 1].
    """
    return tf.cast(x, tf.float32) / 255.0
# --- Hyper-parameters --------------------------------------------------------
learning_rate = 0.001  # step size handed to the Adam optimizer
batch_size = 100       # number of images per training batch
h_shape = 512          # width of the hidden Dense layers
z_shape = 20           # dimensionality of the latent code
image_shape = 28 * 28  # length of one flattened 28x28 image
epochs = 20            # number of passes over the training set

# --- Input pipeline ----------------------------------------------------------
# Only the training images feed the pipeline below; the labels and the test
# split are unpacked here but not used by it.
(x_train, y_train), (x_test, y_test) = datasets.fashion_mnist.load_data()

# Shuffle with a 1000-element buffer, group into batches, then rescale each
# batch with preprocess().
train_db = (
    tf.data.Dataset.from_tensor_slices(x_train)
    .shuffle(1000)
    .batch(batch_size)
    .map(preprocess)
)
class VAE_Model(keras.Model):
    """Fully connected variational autoencoder (VAE).

    The encoder maps an input vector to the mean and log-variance of a
    Gaussian over a ``z_shape``-dimensional latent code. The decoder maps a
    latent code to ``image_shape`` raw logits: the last layer has no
    activation, so a sigmoid must be applied by the caller to obtain pixel
    intensities.

    Layer widths are read from the module-level ``h_shape``, ``z_shape`` and
    ``image_shape`` settings.
    """

    def __init__(self):
        super().__init__()
        # Encoder: one shared hidden layer followed by two linear heads.
        self.fc1 = layers.Dense(h_shape, activation='relu')
        self.fc2 = layers.Dense(z_shape, activation=None)  # mean head
        self.fc3 = layers.Dense(z_shape, activation=None)  # log-variance head
        # Decoder: hidden layer, then a linear layer producing logits.
        self.fc4 = layers.Dense(h_shape, activation='relu')
        self.fc5 = layers.Dense(image_shape, activation=None)

    def encoder(self, x):
        """Encode ``x`` into the parameters of the latent distribution.

        Args:
            x: Input tensor; assumed to be (batch, image_shape) -- confirm
                against the caller.

        Returns:
            Tuple ``(mean, log_var)``, each with ``z_shape`` features.
        """
        out = self.fc1(x)
        mean = self.fc2(out)
        log_var = self.fc3(out)
        return mean, log_var

    def decoder(self, z):
        """Decode latent codes ``z`` into ``image_shape`` logits."""
        out = self.fc4(z)
        out = self.fc5(out)
        return out

    def reparameterize(self, mean, log_var):
        """Sample ``z = mean + std * eps`` with ``eps ~ N(0, 1)``.

        This is the reparameterization trick: the randomness is isolated in
        ``eps`` so gradients can flow through ``mean`` and ``log_var``.

        Args:
            mean: Mean of the latent distribution.
            log_var: Log-variance of the latent distribution, same shape as
                ``mean``.

        Returns:
            A sampled latent tensor with the same shape as ``mean``.
        """
        # Use the runtime shape (tf.shape) rather than the static one so that
        # sampling also works when the batch dimension is not known statically.
        epsilon = tf.random.normal(tf.shape(log_var), dtype=log_var.dtype)
        # exp(0.5 * log_var) equals sqrt(exp(log_var)) mathematically, but
        # the intermediate value overflows much later for large log_var.
        std = tf.exp(0.5 * log_var)
        return mean + std * epsilon

    def call(self, x, training=None):
        """Run the full encode -> sample -> decode pass.

        Args:
            x: Input tensor passed to ``encoder``.
            training: Accepted for Keras API compatibility; not used here.

        Returns:
            Tuple ``(out, mean, log_var)`` where ``out`` holds the decoder
            logits for a latent sample drawn from the encoded distribution.
        """
        mean, log_var = self.encoder(x)
        z = self.reparameterize(mean, log_var)
        out = self.decoder(z)
        return out, mean, log_var
def save_figure(image_data, image_name, grid_size=10, tile_size=28):
    """Tile grayscale images into one square sheet and save it to disk.

    Tiles are laid out column by column: the first ``grid_size`` images fill
    the left-most column from top to bottom, the next ``grid_size`` images
    fill the column to its right, and so on.

    Args:
        image_data: Sequence of 2-D image arrays, indexed from 0. Each entry
            is handed to ``Image.fromarray(..., mode='L')``, so it is expected
            to be ``tile_size`` x ``tile_size`` 8-bit data.
        image_name: Destination path passed to ``Image.save``.
        grid_size: Number of tiles per row and per column (default 10).
        tile_size: Side length of one square tile in pixels (default 28).

    If fewer than ``grid_size ** 2`` images are supplied, the remaining cells
    keep the background of the new image instead of raising ``IndexError``.
    Images beyond ``grid_size ** 2`` are ignored.
    """
    side = grid_size * tile_size
    figure = Image.new(mode='L', size=(side, side))

    # Never index past the end of image_data.
    tile_count = min(len(image_data), grid_size * grid_size)
    for index in range(tile_count):
        # Column-major placement: the quotient selects the column (x offset),
        # the remainder selects the row (y offset).
        column, row = divmod(index, grid_size)
        tile = Image.fromarray(image_data[index], mode='L')
        figure.paste(tile, (column * tile_size, row * tile_size))

    figure.save(image_name)
def _to_uint8_images(images):
    """Scale ``images`` by 255 and return them as a uint8 NumPy array.

    Values are assumed to lie in [0, 1] (normalized inputs or sigmoid outputs).
    """
    return (np.array(images) * 255).astype(np.uint8)


def main():
    """Train the VAE on ``train_db`` and write per-epoch artifacts.

    For every epoch this:
      * minimizes the sum of the reconstruction loss (sigmoid cross-entropy,
        summed over pixels and averaged over the batch) and the KL divergence
        between the encoded Gaussian and a standard normal;
      * prints the losses every 100 steps;
      * saves a sheet with 50 inputs and their 50 reconstructions from the
        last batch of the epoch;
      * saves the model weights;
      * saves a sheet of images decoded from random latent noise.

    Images go to the ``images`` directory and weights to the ``model``
    directory; both are created if missing.
    """
    # Local import keeps this block self-contained.
    import os

    # Create the output directories up front so a whole epoch of training is
    # not lost to a missing folder, and build paths with os.path.join so the
    # script is not tied to the Windows path separator.
    image_dir = 'images'
    model_dir = 'model'
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    optimizer = keras.optimizers.Adam(learning_rate)
    vae_model = VAE_Model()
    vae_model.build(input_shape=(batch_size, image_shape))
    vae_model.summary()

    steps_per_epoch = x_train.shape[0] // batch_size

    for epoch in range(epochs):
        epoch_no = epoch + 1

        for step, x_data in enumerate(train_db):
            # Flatten each image into a vector for the Dense layers.
            x = tf.reshape(x_data, [-1, image_shape])
            with tf.GradientTape() as tape:
                out, mean, log_var = vae_model(x)
                # Reconstruction term: per-pixel sigmoid cross-entropy,
                # summed per sample, then averaged over the batch.
                crossentropy_loss = tf.reduce_sum(
                    tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=out), axis=1)
                crossentropy_meanloss = tf.reduce_mean(crossentropy_loss)
                # Closed-form KL(N(mean, exp(log_var)) || N(0, 1)) per sample.
                kl_loss = -0.5 * tf.reduce_sum(
                    log_var + 1 - mean ** 2 - tf.exp(log_var), axis=1)
                kl_meanloss = tf.reduce_mean(kl_loss)
                loss = crossentropy_meanloss + kl_meanloss
            grads = tape.gradient(loss, vae_model.trainable_variables)
            optimizer.apply_gradients(zip(grads, vae_model.trainable_variables))

            if step % 100 == 0:
                # The first value is the total objective (cross-entropy + KL),
                # so it is labelled "total loss".
                print("epoch[{}/{}] step[{}/{}] total loss: {:.4f} kl loss: {:.4f} crossentropy loss: {:.4f}"
                      .format(epoch_no, epochs, step, steps_per_epoch, loss, kl_meanloss,
                              crossentropy_meanloss))

        # Reconstructions of the last batch of the epoch: 50 inputs followed
        # by their 50 reconstructions.
        x_reshaped = tf.reshape(tf.nn.sigmoid(out), [-1, 28, 28])
        x_concat = tf.concat([x_data[:50], x_reshaped[:50]], axis=0)
        image_path = os.path.join(image_dir, 'image_epoch_%d.png' % epoch_no)
        save_figure(_to_uint8_images(x_concat), image_path)
        print("Epoch {} image has saved in the path: {}".format(epoch_no, image_path))

        # The checkpoint file name is kept exactly as before (including its
        # "VEA" spelling) so previously saved checkpoints remain loadable.
        vae_model.save_weights(os.path.join(model_dir, 'VEA_Model_epoch_%d.ckpt' % epoch_no))
        print("Epoch {} model has saved in the path: \\model".format(epoch_no))

        # Images generated by decoding latent codes drawn from N(0, 1).
        z_noise = tf.random.normal(shape=[batch_size, z_shape])
        noise_recons = tf.nn.sigmoid(vae_model.decoder(z_noise))
        noise_recons = tf.reshape(noise_recons, [-1, 28, 28])
        noise_path = os.path.join(image_dir, 'noise_recons_epoch_%d.png' % epoch_no)
        save_figure(_to_uint8_images(noise_recons), noise_path)
        print("Epoch {} image of noise reconstructing has saved in the path: {}"
              .format(epoch_no, noise_path))
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Model: "vae__model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) multiple 401920
_________________________________________________________________
dense_1 (Dense) multiple 10260
_________________________________________________________________
dense_2 (Dense) multiple 10260
_________________________________________________________________
dense_3 (Dense) multiple 10752
_________________________________________________________________
dense_4 (Dense) multiple 402192
=================================================================
Total params: 835,384
Trainable params: 835,384
Non-trainable params: 0
_________________________________________________________________
epoch[1/20] step[0/600] reconstruct loss: 550.5972 kl loss: 2.8135 crossentropy loss: 547.7836
epoch[1/20] step[100/600] reconstruct loss: 300.4324 kl loss: 14.1489 crossentropy loss: 286.2836
epoch[1/20] step[200/600] reconstruct loss: 269.3903 kl loss: 16.1698 crossentropy loss: 253.2205
epoch[1/20] step[300/600] reconstruct loss: 266.6666 kl loss: 15.6042 crossentropy loss: 251.0624
epoch[1/20] step[400/600] reconstruct loss: 267.5359 kl loss: 16.2817 crossentropy loss: 251.2542
epoch[1/20] step[500/600] reconstruct loss: 269.5238 kl loss: 16.0071 crossentropy loss: 253.5167
Epoch 1 image has saved in the path: images\image_epoch_1.png
Epoch 1 model has saved in the path: model
Epoch 1 image of noise reconstructing has saved in the path: images\noise_recons_epoch_1.png
epoch[2/20] step[0/600] reconstruct loss: 245.8026 kl loss: 16.4432 crossentropy loss: 229.3594
epoch[2/20] step[100/600] reconstruct loss: 258.3123 kl loss: 16.3746 crossentropy loss: 241.9377
epoch[2/20] step[200/600] reconstruct loss: 252.6373 kl loss: 16.7064 crossentropy loss: 235.9309
epoch[2/20] step[300/600] reconstruct loss: 240.0333 kl loss: 16.5543 crossentropy loss: 223.4790
epoch[2/20] step[400/600] reconstruct loss: 257.7153 kl loss: 16.2433 crossentropy loss: 241.4720
epoch[2/20] step[500/600] reconstruct loss: 255.9624 kl loss: 16.3426 crossentropy loss: 239.6198
Epoch 2 image has saved in the path: images\image_epoch_2.png
Epoch 2 model has saved in the path: model
Epoch 2 image of noise reconstructing has saved in the path: images\noise_recons_epoch_2.png
epoch[3/20] step[0/600] reconstruct loss: 251.1254 kl loss: 16.1641 crossentropy loss: 234.9614
epoch[3/20] step[100/600] reconstruct loss: 245.0857 kl loss: 17.1267 crossentropy loss: 227.9590
epoch[3/20] step[200/600] reconstruct loss: 245.2651 kl loss: 16.2374 crossentropy loss: 229.0277
epoch[3/20] step[300/600] reconstruct loss: 252.0008 kl loss: 16.3147 crossentropy loss: 235.6861
epoch[3/20] step[400/600] reconstruct loss: 260.2074 kl loss: 16.5105 crossentropy loss: 243.6969
epoch[3/20] step[500/600] reconstruct loss: 237.4382 kl loss: 15.9900 crossentropy loss: 221.4482
Epoch 3 image has saved in the path: images\image_epoch_3.png
Epoch 3 model has saved in the path: model
Epoch 3 image of noise reconstructing has saved in the path: images\noise_recons_epoch_3.png
epoch[4/20] step[0/600] reconstruct loss: 246.2724 kl loss: 16.1399 crossentropy loss: 230.1325
epoch[4/20] step[100/600] reconstruct loss: 242.9304 kl loss: 15.9318 crossentropy loss: 226.9986
epoch[4/20] step[200/600] reconstruct loss: 247.8724 kl loss: 16.4313 crossentropy loss: 231.4411
epoch[4/20] step[300/600] reconstruct loss: 244.6034 kl loss: 16.1179 crossentropy loss: 228.4855
epoch[4/20] step[400/600] reconstruct loss: 252.1622 kl loss: 15.8833 crossentropy loss: 236.2789
epoch[4/20] step[500/600] reconstruct loss: 241.3502 kl loss: 16.0111 crossentropy loss: 225.3391
Epoch 4 image has saved in the path: images\image_epoch_4.png
Epoch 4 model has saved in the path: model
Epoch 4 image of noise reconstructing has saved in the path: images\noise_recons_epoch_4.png
epoch[5/20] step[0/600] reconstruct loss: 242.7823 kl loss: 16.1403 crossentropy loss: 226.6420
epoch[5/20] step[100/600] reconstruct loss: 238.2191 kl loss: 15.2946 crossentropy loss: 222.9244
epoch[5/20] step[200/600] reconstruct loss: 263.1482 kl loss: 16.3715 crossentropy loss: 246.7767
epoch[5/20] step[300/600] reconstruct loss: 250.9017 kl loss: 15.9819 crossentropy loss: 234.9198
epoch[5/20] step[400/600] reconstruct loss: 253.9848 kl loss: 16.1483 crossentropy loss: 237.8365
epoch[5/20] step[500/600] reconstruct loss: 238.0354 kl loss: 16.2313 crossentropy loss: 221.8042
Epoch 5 image has saved in the path: images\image_epoch_5.png
Epoch 5 model has saved in the path: model
Epoch 5 image of noise reconstructing has saved in the path: images\noise_recons_epoch_5.png
………………………………
训练至第20个epoch时原始图像与重构图像的对比:左5列为原始图像,右5列为重构图像
训练至第20个epoch时由随机噪声生成的图像