PyTorch Autoencoder Source Code

import torch
import torch.nn as nn
import torch.utils.data as Data
from torchvision.datasets import MNIST
from torchvision import transforms as tfs
import matplotlib.pyplot as plt
import numpy as np

# Load the data
train_data = MNIST(root='./mnist/', train=True, download=True, transform=tfs.ToTensor())  # 60,000 training images
print(train_data.data.size())     # (60000, 28, 28)
print(train_data.targets.size())  # (60000)
plt.imshow(train_data.data[0].numpy())  # show the first image (matplotlib's default colormap renders it in color)
plt.show()
train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)  # batch and shuffle
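
# Quick sanity check (my addition, not in the original script): pull one batch
# from the loader to confirm the shapes the training loop expects.
images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([64, 1, 28, 28])
print(labels.shape)  # torch.Size([64])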

# Define the autoencoder
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(784, 400),
            nn.ReLU(True),
            nn.Linear(400, 200),
            nn.ReLU(True),
            nn.Linear(200, 100),
            nn.ReLU(True),
            nn.Linear(100, 3)   # compress each image down to a 3-D code
        )

        self.decoder = nn.Sequential(
            nn.Linear(3, 100),
            nn.ReLU(True),
            nn.Linear(100, 200),
            nn.ReLU(True),
            nn.Linear(200, 400),
            nn.ReLU(True),
            nn.Linear(400, 784),
            nn.Tanh()           # squash reconstructions into [-1, 1]
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

autoencoder = Autoencoder()
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=0.005)  # optimizer
loss_func = nn.MSELoss()  # loss function: mean squared error
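
# Note (my suggestion, not part of the original code): the decoder ends in
# nn.Tanh(), whose range is [-1, 1], while ToTensor scales pixels to [0, 1].
# A Sigmoid output matches the target range exactly; nn.Sequential supports
# item assignment, so the final layer could be swapped in place:
# autoencoder.decoder[-1] = nn.Sigmoid()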

# Create a figure: top row for originals, bottom row for reconstructions
f, a = plt.subplots(2, 10, figsize=(10, 2))
plt.ion()

# View the original data
view_data = train_data.data[:10].view(-1, 28*28).float()/255.
#print(view_data)
for i in range(10):
    a[0][i].imshow(np.reshape(view_data.numpy()[i], (28, 28)))
    a[0][i].set_xticks(())
    a[0][i].set_yticks(())  # hide the tick marks

for epoch in range(5):
    for step, (x, b_label) in enumerate(train_loader):
        #print(x.shape)  # (64, 1, 28, 28)
        b_x = x.view(-1, 28*28)  # batch x, shape (batch, 784)
        #print(b_x.shape)  # (64, 784)
        b_y = x.view(-1, 28*28)  # batch y, shape (batch, 784)

        encoded, decoded = autoencoder(b_x)

        loss = loss_func(decoded, b_y)      # compute the reconstruction loss
        optimizer.zero_grad()               # clear the gradients
        loss.backward()                     # backpropagate
        optimizer.step()                    # update the weights

        if step % 100 == 0:        # report every 100 steps
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item())

            # Plot the decoded (reconstructed) images
            encoded_data, decoded_data = autoencoder(view_data)
            #print(encoded_data.shape)
            for i in range(10):
                a[1][i].clear()
                a[1][i].imshow(np.reshape(decoded_data.data.numpy()[i], (28, 28)))
                a[1][i].set_xticks(()); a[1][i].set_yticks(())
            plt.draw(); plt.pause(0.05)  # pause for 0.05 s

plt.ioff()
plt.show()
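
# Optional follow-up (a sketch, not in the original post): the bottleneck is
# 3-dimensional, so the learned codes can be scattered directly in 3-D space,
# labeled by digit, to see how the classes separate.
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib

view_data = train_data.data[:200].view(-1, 28*28).float()/255.
with torch.no_grad():
    encoded_data, _ = autoencoder(view_data)
fig = plt.figure(2)
ax = fig.add_subplot(111, projection='3d')
X = encoded_data[:, 0].numpy()
Y = encoded_data[:, 1].numpy()
Z = encoded_data[:, 2].numpy()
values = train_data.targets[:200].numpy()
for x, y, z, s in zip(X, Y, Z, values):
    ax.text(x, y, z, str(s), backgroundcolor=plt.cm.rainbow(int(255 * s / 9)))
ax.set_xlim(X.min(), X.max())
ax.set_ylim(Y.min(), Y.max())
ax.set_zlim(Z.min(), Z.max())
plt.show()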
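
# Saving and reusing the trained model (a minimal sketch; the file name
# 'autoencoder.pth' is my choice, not from the original post):
torch.save(autoencoder.state_dict(), 'autoencoder.pth')

model = Autoencoder()
model.load_state_dict(torch.load('autoencoder.pth'))
model.eval()
with torch.no_grad():
    # the encoder alone acts as a learned dimensionality reduction
    codes = model.encoder(train_data.data[:10].view(-1, 28*28).float()/255.)
print(codes.shape)  # torch.Size([10, 3])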
