TensorFlow 2.0 实现基于 Fashion-MNIST 数据的 VAE 网络

# TensorFlow 2.0 实现基于 Fashion-MNIST 数据的 VAE 模型
# coding: utf-8
# coder: Jamin
# date: 2019.12.1

import tensorflow as tf
import numpy as np
from PIL import Image
from tensorflow import keras
from tensorflow.keras import datasets, layers


def preprocess(x):
    """Rescale raw pixel values to float32 intensities.

    Assumes ``x`` holds 0..255 pixel values (e.g. uint8 images), so the result
    lies in [0, 1].
    """
    as_float = tf.cast(x, dtype=tf.float32)
    return as_float / 255.0


# Hyperparameters
learning_rate = 0.001   # Adam step size
batch_size = 100        # images per optimisation step
h_shape = 512           # width of the hidden Dense layer in encoder and decoder
z_shape = 20            # dimensionality of the latent vector z
image_shape = 28 * 28   # length of a flattened 28x28 image
epochs = 20             # number of passes over the training set


# Load Fashion-MNIST; only the training images feed the input pipeline below.
(x_train, y_train), (x_test, y_test) = datasets.fashion_mnist.load_data()

# Input pipeline: shuffle with a 1000-element buffer, batch, then rescale pixels.
train_db = (
    tf.data.Dataset.from_tensor_slices(x_train)
    .shuffle(1000)
    .batch(batch_size)
    .map(preprocess)
)


# VAE network model
class VAE_Model(keras.Model):
    """Fully connected variational autoencoder (VAE).

    Encoder: input -> Dense(h_shape, ReLU) -> two linear heads of width
    ``z_shape`` that produce the latent mean and log-variance.
    Decoder: z -> Dense(h_shape, ReLU) -> Dense(image_shape) logits.  No
    output activation is applied; callers turn the logits into pixel values
    with a sigmoid (or feed them to a sigmoid cross-entropy loss).
    """

    def __init__(self):
        super().__init__()
        # Layer attribute names are kept as fc1..fc5: TF object-based
        # checkpoints written by save_weights() key variables by these names.
        # Encoder
        self.fc1 = layers.Dense(h_shape, activation='relu')
        self.fc2 = layers.Dense(z_shape, activation=None)      # mean head
        self.fc3 = layers.Dense(z_shape, activation=None)      # log-variance head
        # Decoder
        self.fc4 = layers.Dense(h_shape, activation='relu')
        self.fc5 = layers.Dense(image_shape, activation=None)  # pixel logits

    def encoder(self, x):
        """Encode ``x`` into the parameters of q(z|x).

        Returns:
            (mean, log_var): tensors of width ``z_shape``.
        """
        hidden = self.fc1(x)
        mean = self.fc2(hidden)
        log_var = self.fc3(hidden)
        return mean, log_var

    def decoder(self, z):
        """Decode latent vectors ``z`` into per-pixel logits of width ``image_shape``."""
        hidden = self.fc4(z)
        return self.fc5(hidden)

    def reparameterize(self, mean, log_var):
        """Sample z ~ N(mean, exp(log_var)) using the reparameterization trick.

        z = mean + exp(0.5 * log_var) * eps with eps ~ N(0, I), so gradients
        flow through ``mean`` and ``log_var`` while the randomness lives in eps.
        """
        # tf.shape returns the dynamic shape, so sampling also works when the
        # batch dimension is unknown (e.g. inside tf.function); the static
        # ``log_var.shape`` would contain None there.
        epsilon = tf.random.normal(tf.shape(log_var))
        # exp(0.5 * log_var) equals sqrt(exp(log_var)) but avoids overflow or
        # underflow of exp(log_var) before the root, and the infinite gradient
        # of sqrt at 0.
        std = tf.exp(0.5 * log_var)
        return mean + std * epsilon

    def call(self, x, training=None):
        """Run encode -> sample -> decode.

        A new z is sampled on every call; ``training`` is accepted for Keras
        compatibility but not used.

        Returns:
            (logits, mean, log_var)
        """
        mean, log_var = self.encoder(x)
        z = self.reparameterize(mean, log_var)
        logits = self.decoder(z)
        return logits, mean, log_var


def save_figure(image_data, image_name, rows=10, cols=10):
    """Tile grayscale images into one grid image and save it to ``image_name``.

    Tiles are filled column by column: the first ``rows`` images go down the
    left-most column, the next ``rows`` down the second column, and so on.
    Only the first ``rows * cols`` images are used; if fewer are given, the
    remaining cells stay black.

    Args:
        image_data: non-empty sequence (e.g. a numpy array) of equally sized
            2-D uint8 images.
        image_name: output path; PIL picks the file format from the extension.
        rows: number of grid rows (default 10).
        cols: number of grid columns (default 10).

    Raises:
        ValueError: if ``image_data`` is empty.
    """
    if len(image_data) == 0:
        raise ValueError('save_figure() needs at least one image')

    # Tile size comes from the data (28x28 for Fashion-MNIST) instead of being hard-coded.
    tile_h, tile_w = image_data[0].shape[:2]
    figure = Image.new(mode='L', size=(cols * tile_w, rows * tile_h))

    for index in range(min(len(image_data), rows * cols)):
        # Column-major placement: index // rows selects the column, index % rows the row.
        col, row = divmod(index, rows)
        # The mode is inferred from the array (2-D uint8 -> 'L'); passing
        # mode= explicitly is deprecated in recent Pillow releases.
        tile = Image.fromarray(image_data[index])
        figure.paste(tile, (col * tile_w, row * tile_h))

    figure.save(image_name)


def _logits_to_uint8_images(logits):
    """Convert flat decoder logits into an (N, 28, 28) uint8 image array.

    A sigmoid maps each logit to an intensity in (0, 1); the vectors are
    reshaped to 28x28, scaled to 0-255 and truncated to uint8.
    """
    pixels = tf.reshape(tf.nn.sigmoid(logits), [-1, 28, 28])
    return (np.array(pixels) * 255).astype(np.uint8)


def main():
    """Train the VAE on Fashion-MNIST, saving sample grids and weights each epoch.

    After every epoch this:
      * writes ``images/image_epoch_<n>.png``: a 10x10 grid whose left five
        columns are 50 input images from the epoch's last batch and whose
        right five columns are their reconstructions;
      * saves the model weights to ``model/VEA_Model_epoch_<n>.ckpt``;
      * writes ``images/noise_recons_epoch_<n>.png``: images decoded from
        ``batch_size`` latent vectors drawn from a standard normal.

    The ``images`` and ``model`` directories are created if missing, and the
    paths are built with ``os.path.join`` so they are valid on any OS.
    """
    import os  # only main touches the filesystem

    image_dir = 'images'
    model_dir = 'model'
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    optimizer = keras.optimizers.Adam(learning_rate)

    vae_model = VAE_Model()
    # input_shape must be a tuple: a list such as [50, 784] would be read as
    # two separate input shapes.
    vae_model.build(input_shape=(batch_size, image_shape))
    vae_model.summary()

    steps_per_epoch = x_train.shape[0] // batch_size

    for epoch in range(epochs):
        for step, x_data in enumerate(train_db):

            x = tf.reshape(x_data, [-1, image_shape])

            with tf.GradientTape() as tape:
                out, mean, log_var = vae_model(x)

                # Reconstruction term: per-pixel sigmoid cross-entropy between
                # the input and the decoder logits, summed per image.
                crossentropy_loss = tf.reduce_sum(
                    tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=out), axis=1)
                crossentropy_meanloss = tf.reduce_mean(crossentropy_loss)

                # Closed-form KL(N(mean, exp(log_var)) || N(0, I)), summed over latent dims.
                kl_loss = -0.5 * tf.reduce_sum(log_var + 1 - mean ** 2 - tf.exp(log_var), axis=1)
                kl_meanloss = tf.reduce_mean(kl_loss)

                # Total objective (negative ELBO), averaged over the batch.
                loss = crossentropy_meanloss + kl_meanloss

            grads = tape.gradient(loss, vae_model.trainable_variables)
            optimizer.apply_gradients(zip(grads, vae_model.trainable_variables))

            if step % 100 == 0:
                # `loss` is the total (reconstruction + KL), not the reconstruction term alone.
                print("epoch[{}/{}]  step[{}/{}]  total loss: {:.4f}  kl loss: {:.4f}  crossentropy loss: {:.4f}"
                      .format(epoch + 1, epochs, step, steps_per_epoch, loss, kl_meanloss,
                              crossentropy_meanloss))

        # `x_data` and `out` still hold the last batch of this epoch.
        # Originals (already in [0, 1]) and reconstructions, 50 of each.
        originals = (np.array(x_data[:50]) * 255).astype(np.uint8)
        reconstructions = _logits_to_uint8_images(out)[:50]
        x_concat = np.concatenate([originals, reconstructions], axis=0)
        image_path = os.path.join(image_dir, 'image_epoch_%d.png' % (epoch + 1))
        save_figure(x_concat, image_path)
        print("Epoch {} image has saved in the path: {}".format(epoch + 1, image_path))

        # Save weights every epoch.  A subclassed model cannot be written as a
        # single .h5 file via save(), so only the weights are stored.  The
        # "VEA" file stem is kept in case other code loads checkpoints by name.
        model_path = os.path.join(model_dir, 'VEA_Model_epoch_%d.ckpt' % (epoch + 1))
        vae_model.save_weights(model_path)
        print("Epoch {} model has saved in the path: {}".format(epoch + 1, model_path))

        # Decode latent vectors sampled from N(0, I) to inspect generated images.
        z_noise = tf.random.normal(shape=[batch_size, z_shape])
        noise_recons = _logits_to_uint8_images(vae_model.decoder(z_noise))
        noise_path = os.path.join(image_dir, 'noise_recons_epoch_%d.png' % (epoch + 1))
        save_figure(noise_recons, noise_path)
        print("Epoch {} image of noise reconstructing has saved in the path: {}"
              .format(epoch + 1, noise_path))


if __name__ == '__main__':
    # Script entry point: start training only when this file is executed directly.
    main()
# 运行结果
Model: "vae__model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                multiple                  401920    
_________________________________________________________________
dense_1 (Dense)              multiple                  10260     
_________________________________________________________________
dense_2 (Dense)              multiple                  10260     
_________________________________________________________________
dense_3 (Dense)              multiple                  10752     
_________________________________________________________________
dense_4 (Dense)              multiple                  402192    
=================================================================
Total params: 835,384
Trainable params: 835,384
Non-trainable params: 0
_________________________________________________________________
epoch[1/20]  step[0/600]  reconstruct loss: 550.5972  kl loss: 2.8135  crossentropy loss: 547.7836
epoch[1/20]  step[100/600]  reconstruct loss: 300.4324  kl loss: 14.1489  crossentropy loss: 286.2836
epoch[1/20]  step[200/600]  reconstruct loss: 269.3903  kl loss: 16.1698  crossentropy loss: 253.2205
epoch[1/20]  step[300/600]  reconstruct loss: 266.6666  kl loss: 15.6042  crossentropy loss: 251.0624
epoch[1/20]  step[400/600]  reconstruct loss: 267.5359  kl loss: 16.2817  crossentropy loss: 251.2542
epoch[1/20]  step[500/600]  reconstruct loss: 269.5238  kl loss: 16.0071  crossentropy loss: 253.5167
Epoch 1 image has saved in the path: images\image_epoch_1.png
Epoch 1 model has saved in the path: model
Epoch 1 image of noise reconstructing has saved in the path: images\noise_recons_epoch_1.png
epoch[2/20]  step[0/600]  reconstruct loss: 245.8026  kl loss: 16.4432  crossentropy loss: 229.3594
epoch[2/20]  step[100/600]  reconstruct loss: 258.3123  kl loss: 16.3746  crossentropy loss: 241.9377
epoch[2/20]  step[200/600]  reconstruct loss: 252.6373  kl loss: 16.7064  crossentropy loss: 235.9309
epoch[2/20]  step[300/600]  reconstruct loss: 240.0333  kl loss: 16.5543  crossentropy loss: 223.4790
epoch[2/20]  step[400/600]  reconstruct loss: 257.7153  kl loss: 16.2433  crossentropy loss: 241.4720
epoch[2/20]  step[500/600]  reconstruct loss: 255.9624  kl loss: 16.3426  crossentropy loss: 239.6198
Epoch 2 image has saved in the path: images\image_epoch_2.png
Epoch 2 model has saved in the path: model
Epoch 2 image of noise reconstructing has saved in the path: images\noise_recons_epoch_2.png
epoch[3/20]  step[0/600]  reconstruct loss: 251.1254  kl loss: 16.1641  crossentropy loss: 234.9614
epoch[3/20]  step[100/600]  reconstruct loss: 245.0857  kl loss: 17.1267  crossentropy loss: 227.9590
epoch[3/20]  step[200/600]  reconstruct loss: 245.2651  kl loss: 16.2374  crossentropy loss: 229.0277
epoch[3/20]  step[300/600]  reconstruct loss: 252.0008  kl loss: 16.3147  crossentropy loss: 235.6861
epoch[3/20]  step[400/600]  reconstruct loss: 260.2074  kl loss: 16.5105  crossentropy loss: 243.6969
epoch[3/20]  step[500/600]  reconstruct loss: 237.4382  kl loss: 15.9900  crossentropy loss: 221.4482
Epoch 3 image has saved in the path: images\image_epoch_3.png
Epoch 3 model has saved in the path: model
Epoch 3 image of noise reconstructing has saved in the path: images\noise_recons_epoch_3.png
epoch[4/20]  step[0/600]  reconstruct loss: 246.2724  kl loss: 16.1399  crossentropy loss: 230.1325
epoch[4/20]  step[100/600]  reconstruct loss: 242.9304  kl loss: 15.9318  crossentropy loss: 226.9986
epoch[4/20]  step[200/600]  reconstruct loss: 247.8724  kl loss: 16.4313  crossentropy loss: 231.4411
epoch[4/20]  step[300/600]  reconstruct loss: 244.6034  kl loss: 16.1179  crossentropy loss: 228.4855
epoch[4/20]  step[400/600]  reconstruct loss: 252.1622  kl loss: 15.8833  crossentropy loss: 236.2789
epoch[4/20]  step[500/600]  reconstruct loss: 241.3502  kl loss: 16.0111  crossentropy loss: 225.3391
Epoch 4 image has saved in the path: images\image_epoch_4.png
Epoch 4 model has saved in the path: model
Epoch 4 image of noise reconstructing has saved in the path: images\noise_recons_epoch_4.png
epoch[5/20]  step[0/600]  reconstruct loss: 242.7823  kl loss: 16.1403  crossentropy loss: 226.6420
epoch[5/20]  step[100/600]  reconstruct loss: 238.2191  kl loss: 15.2946  crossentropy loss: 222.9244
epoch[5/20]  step[200/600]  reconstruct loss: 263.1482  kl loss: 16.3715  crossentropy loss: 246.7767
epoch[5/20]  step[300/600]  reconstruct loss: 250.9017  kl loss: 15.9819  crossentropy loss: 234.9198
epoch[5/20]  step[400/600]  reconstruct loss: 253.9848  kl loss: 16.1483  crossentropy loss: 237.8365
epoch[5/20]  step[500/600]  reconstruct loss: 238.0354  kl loss: 16.2313  crossentropy loss: 221.8042
Epoch 5 image has saved in the path: images\image_epoch_5.png
Epoch 5 model has saved in the path: model
Epoch 5 image of noise reconstructing has saved in the path: images\noise_recons_epoch_5.png
……（第 6–20 个 epoch 的输出格式相同，此处从略）

第 20 个 epoch 结束时的原始图像与重构图像：左 5 列为原始图像，右 5 列为对应的重构图像。

Tensorflow 2.0实现基于Fashion MNIST数据的VAE网络_第1张图片

第 20 个 epoch 结束时，由标准正态分布采样的隐变量（噪声）经解码器生成的图像。

Tensorflow 2.0实现基于Fashion MNIST数据的VAE网络_第2张图片

你可能感兴趣的:(TensorFlow,2.x,学习笔记)