keras框架下的VAE的代码撰写

20221117 -

0. 引言

在网上搜索keras版本的vae代码,在网上可以找到非常多,而且keras官方文档也提供了一个版本,但是这个版本是通过定义了新的层来进行训练的,跟其他的代码并不太一样,比较常见的代码形式就是如下形式(代码来源于[1]):

x = Input(shape=(original_dim,))
h = Dense(intermediate_dim, activation='relu')(x)

# 算p(Z|X)的均值和方差
z_mean = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)

# 重参数技巧
def sampling(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=K.shape(z_mean))
    return z_mean + K.exp(z_log_var / 2) * epsilon

# 重参数层,相当于给输入加入噪声
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

# 解码层,也就是生成器部分
decoder_h = Dense(intermediate_dim, activation='relu')
decoder_mean = Dense(original_dim, activation='sigmoid')
h_decoded = decoder_h(z)
x_decoded_mean = decoder_mean(h_decoded)

# 建立模型
vae = Model(x, x_decoded_mean)

# xent_loss是重构loss,kl_loss是KL loss
xent_loss = K.sum(K.binary_crossentropy(x, x_decoded_mean), axis=-1)
kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
vae_loss = K.mean(xent_loss + kl_loss)

# add_loss是新增的方法,用于更灵活地添加各种loss
vae.add_loss(vae_loss)
vae.compile(optimizer='rmsprop')
vae.summary()

上述代码从逻辑上来说是没有什么问题的,在定义了编码器和解码器之后,在定义重构和kl损失之后,并对模型进行编译。而这里由于KL损失的形式并没有keras中自定义损失的那种预测和输出的形式,所以要采用这种形式。add_loss这个函数加入之后,训练是没有任何问题,能够正常训练,不管是在样本生成,或者是说,异常检测的代码中,都能看到这种形式。但这种形式引发了一个问题,就是我的损失函数在最后输出的时候,只有一个结果,也就是不能看到重构损失和kl损失的分别数值,这样的话,对于想试试查看具体的运行结果并不是很方便。在很久之前使用这个代码的时候,没有想这么多,就没管,最近翻出来这部分东西,就搜索了这部分内容。那么,为了打印分别损失函数的数值,这里就需要一些代码加入。

1. 修正损失打印

问答[2]中给出了答案。

reconstruction_loss = mse(K.flatten(inputs), K.flatten(outputs))
kl_loss = beta*K.mean(- 0.5 * 1/latent_dim * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1))

model.add_loss(reconstruction_loss)
model.add_loss(kl_loss)

model.add_metric(kl_loss, name='kl_loss', aggregation='mean')
model.add_metric(reconstruction_loss, name='mse_loss', aggregation='mean')

model.compile(optimizer='adam')

上述代码中,增加的部分就是add_metric函数,这个函数能够帮助后续打印的时候,打印具体的数值。

参考

[1]VAE_Keras
[2]Output multiple losses added by add_loss in Keras

你可能感兴趣的:(深度学习,keras,深度学习,人工智能)