"""
Classify the MNIST dataset with an LSTM implemented in PyTorch.
__author__: shuangrui Guo
__description__:
"""
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

class Rnn_LSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, n_layers, n_classes):
        super(Rnn_LSTM, self).__init__()
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        # batch_first=True expects input of shape (batch, seq_len, input_dim)
        self.lstm = nn.LSTM(input_dim, hidden_dim, n_layers, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, n_classes)

    def forward(self, x):
        out, (h_n, c_n) = self.lstm(x)
        # h_n has shape (n_layers, batch, hidden_dim); take the last layer's final hidden state
        x = h_n[-1, :, :]
        x = self.classifier(x)
        return x
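
# (Illustrative sketch, not part of the original script.) Each 28x28 MNIST image is fed to
# the LSTM row by row, i.e. as a sequence of 28 time steps with 28 features each, and the
# last layer's final hidden state is classified:
#   dummy = torch.randn(4, 28, 28)            # (batch, seq_len, input_dim)
#   Rnn_LSTM(28, 10, 2, 10)(dummy).shape      # -> torch.Size([4, 10])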
# Training and evaluation code
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
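# (Note, not in the original.) Normalize([0.5], [0.5]) applies (x - 0.5) / 0.5 to each pixel,
# mapping the [0, 1] range produced by ToTensor() to [-1, 1].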
train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
test_set = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_set, batch_size=100, shuffle=False)
# input_dim=28 (one image row per time step), hidden_dim=10, 2 LSTM layers, 10 digit classes
net = Rnn_LSTM(28, 10, 2, 10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
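# (Optional sketch, assuming a GPU might be available; the original trains on the CPU.)
# One could select a device once and use it instead of the hard-coded 'cpu' below:
#   device = 'cuda' if torch.cuda.is_available() else 'cpu'
#   net = net.to(device)
#   # and inside the loops: inputs, targets = inputs.to(device), targets.to(device)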
# Training
def train(epoch):
    print(f'epoch:{epoch}')
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_index, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to('cpu'), targets.to('cpu')
        optimizer.zero_grad()
        # drop the channel dimension: (batch, 1, 28, 28) -> (batch, 28, 28)
        outputs = net(torch.squeeze(inputs, 1))
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        print(batch_index, len(train_loader),
              'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (train_loss / (batch_index + 1), 100. * correct / total, correct, total))
def test(epoch):
    global best_acc
    # eval() puts the module in evaluation mode. If the network uses BatchNorm or Dropout,
    # this must be called before inference so those layers switch to their test-time
    # behaviour and the predictions match the test data (see the sketch after this function).
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to('cpu'), targets.to('cpu')
            outputs = net(torch.squeeze(inputs, 1))
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            # print(predicted)  # uncomment to inspect per-batch predictions
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            print(batch_idx, len(test_loader),
                  'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
    # keep track of the best test accuracy (declared global above)
    best_acc = max(best_acc, 100. * correct / total)
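
# (Sketch, for illustration only; this model contains neither BatchNorm nor Dropout.) eval()
# matters for layers whose behaviour differs between training and inference, e.g. Dropout:
#   drop = nn.Dropout(p=0.5)
#   drop.train(); drop(torch.ones(4))   # about half the entries zeroed, the rest scaled by 2
#   drop.eval();  drop(torch.ones(4))   # identity: tensor([1., 1., 1., 1.])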
best_acc = 0.0  # best test accuracy observed so far

for epoch in range(200):
    train(epoch)
    test(epoch)
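
# (Sketch, assuming the trained weights should be persisted; not in the original script.)
#   torch.save({'net': net.state_dict(), 'best_acc': best_acc}, './mnist_lstm.pth')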