A VAE consists of the following parts:
1) an encoder network that parameterises the posterior distribution q(z|x), where z is a discrete latent random variable and x is the input data;
2) a prior distribution p(z);
3) a decoder network with a distribution p(x|z) over the input data.
VQ-VAE uses discrete latent variables and trains them in a new way inspired by vector quantisation: the posterior and prior distributions are categorical, and samples drawn from them index an embedding table; the indexed embeddings are then fed into the decoder network.
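In the VQ-VAE paper this categorical posterior is deterministic (one-hot), always picking the nearest embedding e_j:

q(z = k | x) = 1 if k = argmin_j ‖z_e(x) − e_j‖_2, and 0 otherwise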
The network structure is as follows:
The input image x is first passed through the encoder to obtain the feature map z_e(x). The embedding space is then searched for the embedding closest to z_e(x), giving z_q(x), which is fed into the decoder to produce the output. The gradient ∇zL (shown in red in the paper's figure) pushes the encoder to change its output.
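Written out, the quantisation step is a nearest-neighbour lookup in the embedding (codebook) space:

z_q(x) = e_k, where k = argmin_j ‖z_e(x) − e_j‖_2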
The encoder consists of three convolutional layers with 4×4 kernels and stride 2; the decoder consists of three transposed-convolution layers, also with 4×4 kernels and stride 2. The code is as follows:
# Assumes `import tensorflow as tf` and the repo's Conv2d / TransposedConv2d wrappers (4x4 kernels, stride 2).
def _mnist_arch(d):
    with tf.variable_scope('enc') as enc_param_scope:
        enc_spec = [
            Conv2d('conv2d_1', 1, d//4, data_format='NHWC'),
            lambda t, **kwargs: tf.nn.relu(t),
            Conv2d('conv2d_2', d//4, d//2, data_format='NHWC'),
            lambda t, **kwargs: tf.nn.relu(t),
            Conv2d('conv2d_3', d//2, d, data_format='NHWC'),
            lambda t, **kwargs: tf.nn.relu(t),
        ]
    with tf.variable_scope('dec') as dec_param_scope:
        dec_spec = [
            TransposedConv2d('tconv2d_1', d, d//2, data_format='NHWC'),
            lambda t, **kwargs: tf.nn.relu(t),
            TransposedConv2d('tconv2d_2', d//2, d//4, data_format='NHWC'),
            lambda t, **kwargs: tf.nn.relu(t),
            TransposedConv2d('tconv2d_3', d//4, 1, data_format='NHWC'),
            lambda t, **kwargs: tf.nn.sigmoid(t),
        ]
    return enc_spec, enc_param_scope, dec_spec, dec_param_scope
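For reference, a minimal sketch of calling this builder (the width 64 is an assumption, not taken from the snippet):

# Sketch only: build the layer specs with channel width d (assumed d = 64)
enc_spec, enc_param_scope, dec_spec, dec_param_scope = _mnist_arch(64)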
For an input x, the encoder is applied first to obtain z_e:
# Encoder Pass
_t = x
for block in enc_spec:
    _t = block(_t)
z_e = _t
The encoder output is then quantised (discretised) against the embedding table to obtain z_q:
# Middle Area (Compression or Discretize)
# TODO: Gross.. use broadcast instead!
_t = tf.tile(tf.expand_dims(z_e, -2), [1, 1, 1, K, 1])  # [batch,latent_h,latent_w,K,D]
_e = tf.reshape(embeds, [1, 1, 1, K, D])                 # codebook: K embeddings of dimension D
_t = tf.norm(_t - _e, axis=-1)                           # distance to each of the K embeddings
k = tf.argmin(_t, axis=-1)                               # -> [batch,latent_h,latent_w]
z_q = tf.gather(embeds, k)                               # nearest embedding per spatial position

self.z_e = z_e  # -> [batch,latent_h,latent_w,D]
self.k = k
self.z_q = z_q  # -> [batch,latent_h,latent_w,D]
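The TODO above hints at a broadcasting version. A sketch of that alternative (not from the original repo; it assumes embeds has shape [K, D] and uses squared distances, which give the same argmin):

# Sketch: nearest-embedding lookup via broadcasting instead of tf.tile
# ||z_e - e_j||^2 = ||z_e||^2 - 2*<z_e, e_j> + ||e_j||^2
z_e_flat = tf.reshape(z_e, [-1, D])                           # [batch*latent_h*latent_w, D]
dist = (tf.reduce_sum(z_e_flat**2, axis=1, keepdims=True)
        - 2.0 * tf.matmul(z_e_flat, embeds, transpose_b=True)
        + tf.reduce_sum(embeds**2, axis=1))                   # [batch*latent_h*latent_w, K]
k = tf.reshape(tf.argmin(dist, axis=1), tf.shape(z_e)[:3])    # [batch,latent_h,latent_w]
z_q = tf.gather(embeds, k)                                    # [batch,latent_h,latent_w,D]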
The quantised tensor z_q is then passed through the decoder to produce the output self.p_x_z:
# Decoder Pass
_t = z_q
for block in dec_spec:
    _t = block(_t)
self.p_x_z = _t
The loss has three terms: a reconstruction term that makes the output match the input, a vector-quantisation term that pulls the embeddings towards the (gradient-stopped) encoder outputs, and a commitment term that keeps the encoder outputs z_e close to the chosen embeddings z_q:
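In the paper's notation the objective is (sg[·] denotes stop-gradient; the code below implements the log-likelihood term as a pixel-wise MSE):

L = log p(x | z_q(x)) + ‖sg[z_e(x)] − e‖_2^2 + β‖z_e(x) − sg[e]‖_2^2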
# Losses
self.recon = tf.reduce_mean((self.p_x_z - x)**2, axis=[0, 1, 2, 3])
self.vq = tf.reduce_mean(
    tf.norm(tf.stop_gradient(self.z_e) - z_q, axis=-1)**2,
    axis=[0, 1, 2])
self.commit = tf.reduce_mean(
    tf.norm(self.z_e - tf.stop_gradient(z_q), axis=-1)**2,
    axis=[0, 1, 2])
self.loss = self.recon + self.vq + beta * self.commit
# Decoder Grads: the full loss reaches the decoder directly
decoder_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, dec_param_scope.name)
decoder_grads = list(zip(tf.gradients(self.loss, decoder_vars), decoder_vars))
# Encoder Grads: copy the reconstruction gradient at z_q straight through to z_e,
# then add the commitment gradient
encoder_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, enc_param_scope.name)
grad_z = tf.gradients(self.recon, z_q)
encoder_grads = [(tf.gradients(z_e, var, grad_z)[0] + beta * tf.gradients(self.commit, var)[0], var)
                 for var in encoder_vars]
# Embedding Grads: only the VQ term updates the embedding table
embed_grads = list(zip(tf.gradients(self.vq, embeds), [embeds]))
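These three gradient lists can then be handed to a single optimizer; a minimal sketch (the optimizer choice and learning rate are assumptions, not from the snippet):

# Sketch: apply the hand-assembled gradients in one training op
optimizer = tf.train.AdamOptimizer(learning_rate=2e-4)  # assumed hyper-parameter
train_op = optimizer.apply_gradients(decoder_grads + encoder_grads + embed_grads)

For comparison, many VQ-VAE implementations achieve the same "copy the decoder gradient straight to the encoder" behaviour without assembling gradients by hand, via a stop-gradient trick (a sketch, not code from this repo):

# Straight-through estimator: the forward value equals z_q, but the reconstruction
# gradient flows into z_e (and hence the encoder) unchanged.
z_q_st = z_e + tf.stop_gradient(z_q - z_e)
# Feeding z_q_st to the decoder lets optimizer.minimize(self.loss) train encoder and
# decoder directly; self.vq still updates the embeddings, self.commit still constrains z_e.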