VAE(Variational Autoencoder)简单记录

前言

经常遇到它,然而每次小补之后又忘了,害,干脆开一篇慢慢记录一下吧。
VAE -> VQVAE, 主要是加了Vector Quantization

本文会不断更新…

这篇写的不错,有空好好看看 变分自编码器VAE:原来是这么一回事 | 附开源代码, 苏剑林大佬的文章。


理论



代码

以下代码来源: https://zhuanlan.zhihu.com/p/151587288

class VAE(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()
        
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 256),
                                     nn.ReLU(),
                                     nn.Linear(256, 128))
        
        self.mu     = nn.Linear(128, latent_dim)
        self.logvar = nn.Linear(128, latent_dim)
        
        self.latent_mapping = nn.Linear(latent_dim, 128)
        
        self.decoder = nn.Sequential(nn.Linear(128, 256),
                                     nn.ReLU(),
                                     nn.Linear(256, 28 * 28))
        
        
    def encode(self, x):
        x = x.view(x.size(0), -1)
        encoder = self.encoder(x)
        mu, logvar = self.mu(encoder), self.logvar(encoder)
        return mu, logvar
        
    def sample_z(self, mu, logvar):
        eps = torch.rand_like(mu)
        return mu + eps * torch.exp(0.5 * logvar)
    
    def decode(self, z,x):
        latent_z = self.latent_mapping(z)
        out = self.decoder(latent_z)
        reshaped_out = torch.sigmoid(out).view(x.shape[0],1, 28,28)
        return reshaped_out
        
    def forward(self, x):
        
        mu, logvar = self.encode(x)
        z = self.sample_z(mu, logvar)
        output = self.decode(z,x)
        
        return output

你可能感兴趣的:(#,图像生成模型,VAE,计算机视觉,深度学习)