import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np
import matplotlib.pyplot as plt
import torchvision
class AutoEncodeNet(nn.Module):
    """Fully connected autoencoder with a classifier attached to the bottleneck.

    Three sub-networks are built:

    * ``encoder``: 784 -> 128 -> 64 -> 12 -> 3, with Tanh between layers.
    * ``decoder``: 3 -> 12 -> 64 -> 128 -> 784, ending in a Sigmoid so the
      reconstruction lies in (0, 1).
    * ``classfier``: 3 -> 128 -> 10, producing one raw score (logit) per class.
    """

    def __init__(self):
        super().__init__()
        # Encoder.
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 12),
            nn.Tanh(),
            nn.Linear(12, 3),  # compress to 3 features so the code can be plotted in 3D
        )
        # Decoder.
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.Tanh(),
            nn.Linear(12, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, 28 * 28),
            nn.Sigmoid(),  # keeps the reconstructed values in (0, 1)
        )
        # Classifier. The attribute name (spelling included) is kept as-is so
        # that state_dict keys and previously saved models keep matching.
        self.classfier = nn.Sequential(
            nn.Linear(3, 128),
            nn.Tanh(),
            # No trailing Sigmoid: the training code in this file passes these
            # scores to nn.CrossEntropyLoss, which applies log-softmax itself
            # and expects unnormalized logits. Squashing them into (0, 1)
            # would put a floor under the achievable classification loss.
            nn.Linear(128, 10),
        )

    def forward(self, x):
        """Encode ``x``, then decode and classify the resulting code.

        Args:
            x: float tensor whose last dimension has 28 * 28 entries.

        Returns:
            A tuple ``(encoded, decoded, logits)`` whose last dimensions are
            3, 28 * 28 and 10 respectively. ``logits`` are raw class scores;
            take ``argmax`` over the last dimension for the predicted class.
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        logits = self.classfier(encoded)
        return encoded, decoded, logits
def train():
    """Train ``AutoEncodeNet`` on the MNIST training set with a live preview.

    Each batch is optimized on the sum of a reconstruction loss (MSE between
    the decoder output and the flattened input) and a classification loss
    (cross entropy between the classifier output and the digit label).

    After every epoch the first ``N_TEST_IMG`` training images are pushed
    through the model: their reconstructions are drawn in the second row of
    the figure, their predicted digits are printed, the loss of the last
    batch is printed, and the whole model object is saved to
    ``autoencoder_3<epoch>.pkl`` in the current working directory.
    """
    # Hyperparameters.
    EPOCH = 20
    BATCH_SIZE = 128  # the value the data loader was actually running with
    LR = 0.005
    DOWNLOAD_MNIST = False  # keep False once the data set is already on disk
    N_TEST_IMG = 5  # number of images shown in the preview figure

    # MNIST digits data set.
    train_data = torchvision.datasets.MNIST(
        root='./mnist/',
        train=True,  # this is training data
        # Converts a PIL.Image or numpy.ndarray to a torch.FloatTensor of
        # shape (C x H x W) normalized to the range [0.0, 1.0].
        transform=torchvision.transforms.ToTensor(),
        download=DOWNLOAD_MNIST,  # download it if you don't have it
    )

    # To continue from a saved model instead, load it here, e.g.
    # torch.load("autoencoder_115.pkl").
    autoencoder = AutoEncodeNet()
    optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
    reconstruction_loss = nn.MSELoss()
    classification_loss = nn.CrossEntropyLoss()
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=BATCH_SIZE, shuffle=True
    )

    _, ax = plt.subplots(2, N_TEST_IMG)
    plt.ion()  # interactive mode: the figure is redrawn while training runs

    # Top row: the original images used for the preview.
    preview = train_data.data[:N_TEST_IMG].view(-1, 28, 28).type(torch.FloatTensor) / 255.
    for i in range(N_TEST_IMG):
        ax[0][i].imshow(preview[i])
    preview_flat = preview.view(-1, 28 * 28)

    for epoch in range(EPOCH):
        for x, b_label in train_loader:
            # The flattened batch is both the input and the reconstruction target.
            b_x = x.view(-1, 28 * 28)  # shape (batch, 28*28)
            _, decoded, scores = autoencoder(b_x)

            loss = reconstruction_loss(decoded, b_x) + classification_loss(scores, b_label)
            optimizer.zero_grad()  # clear gradients for this training step
            loss.backward()  # backpropagation, compute gradients
            optimizer.step()  # apply gradients

        # Preview only: no gradients are needed for this forward pass.
        with torch.no_grad():
            _, preview_decoded, preview_scores = autoencoder(preview_flat)
        reconstructions = preview_decoded.view(-1, 28, 28)
        predicted_digits = preview_scores.argmax(dim=1).tolist()

        # Bottom row: current reconstructions, plus the predicted digit of each.
        for i in range(N_TEST_IMG):
            ax[1][i].clear()
            ax[1][i].imshow(reconstructions[i].numpy())
            print(predicted_digits[i], end=" , ")
        print("------")
        plt.draw()
        plt.pause(0.01)

        print(loss)  # loss of the last batch of this epoch
        # torch.save pickles the whole module object, not just its weights.
        torch.save(autoencoder, f"autoencoder_3{epoch}.pkl")

    # Leave interactive mode first so that show() blocks and the final figure
    # stays on screen until the window is closed.
    plt.ioff()
    plt.show()
def test():
    """Reconstruct one random MNIST training image with a saved model.

    Loads the model stored in ``autoencoder_8.pkl``, picks a random sample
    from the MNIST training set, prints the shape of the flattened input and
    the encoded vector, and shows the original image next to its
    reconstruction.
    """
    # MNIST digits data set.
    DOWNLOAD_MNIST = False
    train_data = torchvision.datasets.MNIST(
        root='./mnist/',
        train=True,  # this is training data
        # Converts a PIL.Image or numpy.ndarray to a torch.FloatTensor of
        # shape (C x H x W) normalized to the range [0.0, 1.0].
        transform=torchvision.transforms.ToTensor(),
        download=DOWNLOAD_MNIST,  # download it if you don't have it
    )

    # torch.load unpickles arbitrary objects: only load files you trust.
    # NOTE(review): train() in this file writes "autoencoder_3<epoch>.pkl";
    # confirm that this file name exists from an earlier run.
    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True, which rejects whole pickled modules - confirm
    # against the installed version.
    net = torch.load("autoencoder_8.pkl")

    # randrange excludes the upper bound, so the index is always valid
    # (randint(0, 60000) could return 60000, one past the last sample).
    index = random.randrange(len(train_data))
    image, _ = train_data[index]  # fetch the sample once; the label is unused
    flat_image = image.view(-1, 28 * 28)
    print(flat_image.shape)

    # The model returns (encoded, decoded, class scores); the scores are not
    # needed here.
    en, de, _ = net(flat_image)

    fig, [ax1, ax2] = plt.subplots(1, 2)
    print(en)
    ax1.imshow(image.view(28, 28).detach().numpy())
    ax2.imshow(de.view(28, 28).detach().numpy())
    plt.show()
if __name__ == "__main__":
    # Script entry point: run training. Call test() here instead to inspect
    # a previously saved model.
    train()