变分自编码器是自编码器的变种,是一种生成模型。
VAE_main.py (VAE主函数)
import torch
from torch.utils.data import DataLoader
from torchvision import transforms,datasets
from torch import nn,optim
from VAE import VAE
import visdom
def main(epochs=100, batch_size=32, lr=1e-3):
    """Train the VAE on MNIST and push reconstructions to visdom.

    A visdom server must already be running (``python -m visdom.server``),
    otherwise the ``viz.images`` calls have nowhere to send their data.

    Args:
        epochs: Number of passes over the training set.
        batch_size: Mini-batch size used for both the train and test loaders.
        lr: Learning rate handed to the Adam optimizer.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])

    mnist_train = datasets.MNIST('mnist', train=True, transform=to_tensor,
                                 download=True)
    mnist_train = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

    mnist_test = datasets.MNIST('mnist', train=False, transform=to_tensor,
                                download=True)
    mnist_test = DataLoader(mnist_test, batch_size=batch_size, shuffle=True)

    # Unsupervised learning: the labels are discarded.
    # Iterators have no ``.next()`` method in Python 3, so use the builtin.
    x, _ = next(iter(mnist_train))
    print('x:', x.shape)

    model = VAE()
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    print(model)

    viz = visdom.Visdom()

    for epoch in range(epochs):
        for x, _ in mnist_train:
            x_hat, kld = model(x)
            loss = criteon(x_hat, x)

            if kld is not None:
                # Maximizing the ELBO is the same as minimizing
                # (reconstruction error + KL divergence).
                elbo = -loss - 1.0 * kld
                loss = -elbo

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Report the values of the last mini-batch of this epoch.
        kld_value = kld.item() if kld is not None else float('nan')
        print('epoch:', epoch, 'loss:', loss.item(), 'kld:', kld_value)

        # x is a batch of test images; the labels (_) are not used.
        x, _ = next(iter(mnist_test))
        with torch.no_grad():
            x_hat, kld = model(x)
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))


if __name__ == '__main__':
    main()
VAE.py (定义VAE网络结构)
import torch
from torch import nn
import numpy as np
class VAE(nn.Module):
    """Fully connected variational autoencoder for 28x28 single-channel images.

    The encoder maps a flattened image ([b, 784]) to 20 values that are split
    into a 10-dimensional mean ``mu`` and a 10-dimensional ``sigma``. A latent
    code is sampled with the reparameterization trick and decoded back to an
    image with values in [0, 1].
    """

    def __init__(self):
        super().__init__()
        # [b, 784] => [b, 20], later split into mu: [b, 10] and sigma: [b, 10].
        # There is deliberately no activation after the last Linear layer: a
        # trailing ReLU would clamp mu (and sigma) to be non-negative, which
        # prevents the posterior mean from ever taking negative values.
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
        )
        # [b, 10] => [b, 784]; the Sigmoid keeps pixel values in [0, 1].
        self.decoder = nn.Sequential(
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Args:
            x: Tensor whose first dimension is the batch and whose remaining
                dimensions flatten to 784 values per sample.

        Returns:
            A tuple ``(x_hat, kld)``: the reconstruction reshaped to
            ``[b, 1, 28, 28]`` and the scalar KL divergence between the
            encoded Gaussian and N(0, 1), divided by ``b * 28 * 28``.
        """
        batch_size = x.size(0)
        # flatten
        x = x.view(batch_size, 784)

        # encoder: [b, 20], holding both the mean and sigma
        h_ = self.encoder(x)
        # [b, 20] => [b, 10] and [b, 10]; chunk splits along dim=1
        mu, sigma = h_.chunk(2, dim=1)

        # Reparameterization trick: sampling is moved into eps ~ N(0, 1) so
        # gradients can flow through mu and sigma. randn_like draws from the
        # standard normal; rand_like would draw from uniform [0, 1).
        h = mu + sigma * torch.randn_like(sigma)

        # decoder
        x_hat = self.decoder(h)
        # reshape back to image layout
        x_hat = x_hat.view(batch_size, 1, 28, 28)

        # KL divergence between N(mu, sigma^2) and N(0, 1), normalized per
        # pixel; 1e-8 keeps the log finite when sigma is exactly 0.
        kld = 0.5 * torch.sum(
            torch.pow(mu, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (batch_size * 28 * 28)

        return x_hat, kld
在运行 VAE_main.py 之前,需要先在控制台执行以下命令启动 visdom 服务(默认地址为 http://localhost:8097),否则可视化窗口无法显示:
python -m visdom.server