自编码器(autoencoder, AE)是一类在半监督学习和非监督学习中使用的人工神经网络(Artificial Neural Networks, ANNs),其功能是通过将输入信息作为学习目标,对输入信息进行表征学习(representation learning)。
通过算法模型包含两个主要的部分:
Encoder(编码器)和Decoder(解码器)
。编码器的作用是把高维输入 X 编码成低维的隐变量 h ,从而强迫神经网络学习最有信息量的特征;解码器的作用是把隐藏层的隐变量 h 还原到初始维度,最好的状态就是解码器的输出能够完美地或者近似恢复出原来的输入。通过AE可以实现降维的作用。
AE_main.py (定义主函数和迭代次数)
import torch
from torch.utils.data import DataLoader
from torchvision import transforms,datasets
from torch import nn,optim
from AutoEncoder import AE
import visdom
def main():
mnist_train=datasets.MNIST('mnist',True,transform=transforms.Compose([
transforms.ToTensor()
]),download=True)
mnist_train=DataLoader(mnist_train,batch_size=32,shuffle=True)
mnist_test = datasets.MNIST('mnist', False, transform=transforms.Compose([
transforms.ToTensor()
]), download=True)
mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)
x,_=iter(mnist_train).next()
print('x:',x.shape)
model=AE()
criteon=nn.MSELoss()
optimizer=optim.Adam(model.parameters(),lr=1e-3)
print(model)
viz=visdom.Visdom()
for epoch in range(1000):
for batch_size,(x,_) in enumerate(mnist_train):
x_hat=model(x)
loss=criteon(x_hat,x)
#backprop
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(epoch,'loss:',loss.item())
x,_=iter(mnist_test).next() #其中x是标签,_是label
with torch.no_grad():
x_hat=model(x)
viz.images(x,nrow=8,win='x',opts=dict(title='x'))
viz.images(x_hat,nrow=8,win='x_hat',opts=dict(title='x_hat'))
if __name__ == '__main__':
main()
AutoEncoder.py (定义网络架构)
import torch
from torch import nn
class AE(nn.Module):
def __init__(self):
super(AE, self).__init__()
#[b,784]=>[b,20]
self.encoder=nn.Sequential(
nn.Linear(784,256),
nn.ReLU(),
nn.Linear(256,64),
nn.ReLU(),
nn.Linear(64,20),
nn.ReLU()
)
#[b,20]=>[b,784]
self.decoder=nn.Sequential(
nn.Linear(20,64),
nn.ReLU(),
nn.Linear(64,256),
nn.ReLU(),
nn.Linear(256,784),
nn.Sigmoid()
)
def forward(self,x):
batch_size=x.size(0)
#flatten
x=x.view(batch_size,784)
#encoder
x=self.encoder(x)
# decoder
x=self.decoder(x)
#reshape
x=x.view(batch_size,1,28,28)
return x
在运行AE_main.py文件之前需要在控制台运行如下代码,打开visdom
python -m visdom.server
打开这个 http://localhost:8097 网站即可查看训练的过程图片