目录
1. 引言
2.结构
3.搭建网络
4.代码
今天我们来学习一种在压缩数据方面有较好效果的神经网络:自编码(autoencoder)。
自编码的主要思想就是将数据先不断encode进行降维提取其中的关键信息,再decode解码成新的信息,我们的目标是要使我们生成的信息和原信息尽可能相似,有点像我们做题一样,先刷大量的题,提取其中的关键解题思路,遇到同类型的题目就会写了,基本结构如下
这个网络先将信息不断压缩,再解压,对比原始数据和新数据的差别,再反向传播修正参数,最后输出的新数据就会越来越接近原始数据,最后最中间的一层就是这组信息的关键特征。
我们以输出手写数字为例,我们还是只看网络的搭建过程:
#搭建网络
class AutoEncoder(nn.Module):
def __init__(self):
super(AutoEncoder, self).__init__()
self.encoder = nn.Sequential(
nn.Linear(28*28, 128),
nn.Tanh(),
nn.Linear(128, 64),
nn.Tanh(),
nn.Linear(64, 12),
nn.Tanh(),
nn.Linear(12, 3),
)
self.decoder = nn.Sequential(
nn.Linear(3, 12),
nn.Tanh(),
nn.Linear(12, 64),
nn.Tanh(),
nn.Linear(64, 128),
nn.Tanh(),
nn.Linear(128, 28*28),
nn.Sigmoid(),
)
def forward(self, x):
encoded = self.encoder(x)
decoded = self.decoder(encoded)
return encoded, decoded
我们需要构建编码和解码两个网络,编码部分就是将数据一层一层压缩,解码部分就是对应着一层层解压,中间加入非线性函数,最后再加一个激活函数,但是要确保数据的范围与原来一样,因为我们的目标是生成与原数据一样的数据,最后前向传播先编码再解码即可。
下面是效果
开始:
中间:
最后:
可以看出输出图片与原始图片越来越接近了。
最后的损失:
loss:0.33
#调库
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
import numpy as np
#超参数
EPOCH = 10
BATCH_SIZE = 64
LR = 0.005 # learning rate
DOWNLOAD_MNIST = False
N_TEST_IMG = 5
#下载数据
train_data = torchvision.datasets.MNIST(
root='./mnist/',
train=True,
transform=torchvision.transforms.ToTensor(),
download=DOWNLOAD_MNIST,
)
#小批数据
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
#搭建网络
class AutoEncoder(nn.Module):
def __init__(self):
super(AutoEncoder, self).__init__()
self.encoder = nn.Sequential(
nn.Linear(28*28, 128),
nn.Tanh(),
nn.Linear(128, 64),
nn.Tanh(),
nn.Linear(64, 12),
nn.Tanh(),
nn.Linear(12, 3),
)
self.decoder = nn.Sequential(
nn.Linear(3, 12),
nn.Tanh(),
nn.Linear(12, 64),
nn.Tanh(),
nn.Linear(64, 128),
nn.Tanh(),
nn.Linear(128, 28*28),
nn.Sigmoid(),
)
def forward(self, x):
encoded = self.encoder(x)
decoded = self.decoder(encoded)
return encoded, decoded
autoencoder = AutoEncoder()
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
loss_func = nn.MSELoss()
f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
plt.ion()
view_data = train_data.train_data[:N_TEST_IMG].view(-1, 28*28).type(torch.FloatTensor)/255.
for i in range(N_TEST_IMG):
a[0][i].imshow(np.reshape(view_data.data.numpy()[i], (28, 28)), cmap='gray'); a[0][i].set_xticks(()); a[0][i].set_yticks(())
#训练并绘图
for epoch in range(EPOCH):
for step, (x, b_label) in enumerate(train_loader):
b_x = x.view(-1, 28*28)
b_y = x.view(-1, 28*28)
encoded, decoded = autoencoder(b_x)
loss = loss_func(decoded, b_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if step % 100 == 0:
print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy())
_, decoded_data = autoencoder(view_data)
for i in range(N_TEST_IMG):
a[1][i].clear()
a[1][i].imshow(np.reshape(decoded_data.data.numpy()[i], (28, 28)), cmap='gray')
a[1][i].set_xticks(()); a[1][i].set_yticks(())
plt.draw(); plt.pause(0.05)
plt.ioff()
plt.show()