import torch
import torch.nn as nn
import torch.utils.data as Data
from torchvision.datasets import MNIST
from torchvision import transforms as tfs
import matplotlib.pyplot as plt
import numpy as np
# Load the data: the 60000-image MNIST training split, converted to float
# tensors by ToTensor. download=True fetches it into ./mnist/ on first run
# (without it the constructor raises when the files are missing).
train_data = MNIST(root='./mnist/', train=True, transform=tfs.ToTensor(), download=True)
# .data / .targets replace the deprecated .train_data / .train_labels aliases.
print(train_data.data.size())     # (60000, 28, 28)
print(train_data.targets.size())  # (60000)
# Show the first training image (default colormap, so it renders in colour
# rather than grayscale).
plt.imshow(train_data.data[0].numpy())
plt.show()
# Split into mini-batches of 64 and shuffle every epoch.
train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
#定义自动编码器
class Autoencoder(nn.Module):
    """Fully connected autoencoder for flattened 28x28 images.

    The encoder maps a 784-dim input to a 3-dim code through
    784 -> 400 -> 200 -> 100 -> 3 linear layers with ReLU activations;
    the decoder mirrors it (3 -> 100 -> 200 -> 400 -> 784) and ends in Tanh.
    """

    def __init__(self):
        # Must be __init__ (not `init`) so nn.Module registers the submodules.
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(784, 400),
            nn.ReLU(True),
            nn.Linear(400, 200),
            nn.ReLU(True),
            nn.Linear(200, 100),
            nn.ReLU(True),
            nn.Linear(100, 3),  # 3-dim bottleneck code, no activation
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 100),
            nn.ReLU(True),
            nn.Linear(100, 200),
            nn.ReLU(True),
            nn.Linear(200, 400),
            nn.ReLU(True),
            nn.Linear(400, 784),
            # NOTE(review): Tanh outputs lie in (-1, 1), while the targets used
            # by this script (ToTensor / 255 scaling) are in [0, 1]; Sigmoid
            # would match that range exactly — consider switching.
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode and then decode ``x``.

        Args:
            x: tensor whose last dimension is 784 (a flattened 28x28 image).

        Returns:
            A tuple ``(encoded, decoded)``: the 3-dim code and the 784-dim
            reconstruction.
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded
autoencoder = Autoencoder()
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=0.005)  # optimizer
loss_func = nn.MSELoss()  # loss function: mean squared error

# Create a 2x10 figure: row 0 shows original images, row 1 reconstructions.
f, a = plt.subplots(2, 10, figsize=(10, 2))
plt.ion()  # interactive mode so the figure can be redrawn during training

# First 10 training images, flattened to (10, 784) and scaled from the raw
# uint8 range [0, 255] to float [0, 1].
view_data = train_data.data[:10].view(-1, 28 * 28).float() / 255

for i in range(10):
    a[0][i].imshow(np.reshape(view_data[i].numpy(), (28, 28)))
    a[0][i].set_xticks(())
    a[0][i].set_yticks(())  # hide axis ticks
for epoch in range(5):
    for step, (x, _) in enumerate(train_loader):
        # x is (batch, 1, 28, 28) -> flatten to (batch, 784).
        b_x = x.view(-1, 28 * 28)  # input batch
        b_y = b_x                  # target: the autoencoder reconstructs its input
        encoded, decoded = autoencoder(b_x)
        loss = loss_func(decoded, b_y)  # compute the loss
        optimizer.zero_grad()           # clear accumulated gradients
        loss.backward()                 # backpropagate
        optimizer.step()                # apply the parameter update

        if step % 100 == 0:  # report and redraw every 100 steps
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item())
            # Redraw the reconstructions; inference only, so skip autograd.
            with torch.no_grad():
                encoded_data, decoded_data = autoencoder(view_data)
            for i in range(10):
                a[1][i].clear()
                a[1][i].imshow(np.reshape(decoded_data[i].numpy(), (28, 28)))
                a[1][i].set_xticks(())
                a[1][i].set_yticks(())
            plt.draw()
            plt.pause(0.05)  # pause 0.05 s so the GUI event loop can refresh

plt.ioff()
plt.show()