本题目尝试用 VAE 生成 MNIST 风格的手写数字。与 AE 不同,VAE 试图从特定分布中采样出一个隐变量,交由解码器学习一个与观测数据相同的分布。然后从学习得到的分布中采样得到新数据。采用 MNIST 数据训练一个 VAE 模型(卷积网络或多层感知机网络),并使用学习好的 VAE 模型,生成与训练数据相似的新图像,并将其打印出来。
损失函数实现
def loss_func(recon_x, x, mu, logvar):
BCE_loss = nn.BCELoss(reduction='sum')
recon_loss = BCE_loss(recon_x, x)
KLD = -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp())
return recon_loss + KLD
编码器部分能够学习到根据输入样本X来形成一个特定分布,从中我们可以对一个隐藏变量进行采样,而这个隐藏变量极有可能生成X里面的样本。为了使得Q(z|X)服从高斯分布,这部分需要被优化。
解码器部分能够学习到根据给定的一个隐藏变量z作为输入,生成一个具有真实数据分布的输出。该部分将经过采样后的z(最初来自正态分布)映射到一个更复杂的隐藏空间去(实际数据的空间),并通过这个复杂的隐藏变量z生成一个个的数据点,这些数据点十分接近真实数据的分布。
VAE工作流程图
实现代码
class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__()
self.fc1 = nn.Linear(28 * 28, 400)
self.fc2_mean = nn.Linear(400, 20)
self.fc2_logvar = nn.Linear(400, 20)
self.fc3 = nn.Linear(20, 400)
self.fc4 = nn.Linear(400, 28 * 28)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
def encode(self, x):
h1 = self.relu(self.fc1(x))
return self.fc2_mean(h1), self.fc2_logvar(h1)
def reparamertrize(self, mu, logvar):
std = torch.exp(logvar / 2)
eps = torch.rand_like(std)
return eps * std + mu
def decode(self, z):
h3 = self.relu(self.fc3(z))
return self.sigmoid(self.fc4(h3))
def forward(self, x):
mu, logvar = self.encode(x)
z = self.reparamertrize(mu, logvar)
return mu, logvar, self.decode(z)
BATCH_SIZE = 128
train_data = datasets.MNIST(
root='./dataset/',
train=True,
transform=transforms.ToTensor(),
download=True
)
train_loader = DataLoader(
dataset=train_data,
batch_size=BATCH_SIZE,
shuffle=True
)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
LR = 1e-3
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
pca = decomposition.PCA()
# 记录loss变化
loss_list = []
for epoch in trange(EPOCH):
epoch_iterator = tqdm(train_loader, desc="Iteration")
for step, batch in enumerate(epoch_iterator):
# 由于是无监督模型,只采用data部分
data, targets = batch
real_imgs = data.view(-1, 28 * 28).to(device)
mu, logvar, gen_imgs = model(real_imgs)
loss = loss_func(gen_imgs, real_imgs, mu, logvar)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Epoch {}, loss: {}'.format(epoch, loss))
loss_list.append(loss)
# 将原始图像与生成图像拼接绘制,左侧为原始图像,右侧为生成图像
concat_imgs = torch.cat([real_imgs.view(-1, 1, 28, 28),
gen_imgs.view(-1, 1, 28, 28)], dim=3)
save_image(concat_imgs, 'images/concat_image-{}.png'.format(epoch))
plt.plot(range(len(loss_list)), loss_list, label='loss')
plt.legend()
plt.show()
with torch.no_grad():
mu_re = pca.fit_transform(mu.cpu().numpy())[0, 0]
logvar_re = pca.fit_transform(logvar.cpu().numpy())[0, 0]
x = np.linspace(mu_re - 6 * logvar_re, mu_re + 6 * mu_re, 100)
y = normal_distribution(x, mu_re, logvar_re)
plt.plot(x, y, color='b')
plt.show()