# 使用 PyTorch 的 LSTM 实现 MNIST 数据集分类任务

# 使用 PyTorch 的 LSTM 实现 MNIST 数据集分类任务

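"""
__author__:shuangrui Guo
__description__: Classify MNIST handwritten digits with a PyTorch LSTM.
    Each 28x28 image is fed to the LSTM as a sequence of 28 rows
    (28 features per time step); the final hidden state of the top
    LSTM layer is mapped by a linear layer to 10 class logits.
"""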
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader


class Rnn_LSTM(nn.Module):
    """Sequence classifier built from a stacked LSTM and a linear head.

    The input sequence is run through a batch-first, multi-layer LSTM.
    The final hidden state of the top layer is projected to one logit
    per class.

    Args:
        input_dim: number of features in each time step.
        hidden_dim: size of the LSTM hidden state.
        n_layers: number of stacked LSTM layers.
        n_classes: number of output logits.
    """

    def __init__(self, input_dim, hidden_dim, n_layers, n_classes):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        # batch_first=True: inputs are laid out as (batch, seq_len, input_dim).
        self.lstm = nn.LSTM(input_dim, hidden_dim, n_layers, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, n_classes)

    def forward(self, x):
        """Return class logits of shape (batch, n_classes) for the sequence batch ``x``."""
        # Only the final hidden state is needed; per-step outputs and the
        # cell state are discarded.
        _, (final_hidden, _) = self.lstm(x)
        # final_hidden has shape (n_layers, batch, hidden_dim); keep the top layer.
        top_layer_state = final_hidden[-1]
        return self.classifier(top_layer_state)

# Training / evaluation setup: data pipeline, model, loss and optimizer.

# ToTensor scales pixels to [0, 1]; Normalize((x - 0.5) / 0.5) maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

train_set = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)

test_set = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
test_loader = DataLoader(test_set, batch_size=100, shuffle=False)

# 28 features per time step (one image row), 10 hidden units, 2 stacked
# LSTM layers, 10 output classes (digits 0-9).
net = Rnn_LSTM(input_dim=28, hidden_dim=10, n_layers=2, n_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)


# Training
def train(epoch):
    """Run one training epoch of ``net`` over ``train_loader``.

    Uses the module-level ``net``, ``train_loader``, ``criterion`` and
    ``optimizer``. After every batch it prints the batch index, the number
    of batches, the running mean loss and the running accuracy.

    Args:
        epoch: epoch number; only used in the header line printed first.
    """
    print(f'epoch:{epoch}')
    net.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    for step, (images, labels) in enumerate(train_loader):
        images = images.to('cpu')
        labels = labels.to('cpu')

        optimizer.zero_grad()
        # MNIST batches arrive as (batch, 1, 28, 28). Dropping the channel
        # axis gives (batch, 28, 28): 28 time steps (rows) of 28 features each.
        logits = net(images.squeeze(1))
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        predicted = logits.max(dim=1).indices
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()
        print(step, len(train_loader),
              'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
                  running_loss / (step + 1),
                  100 * n_correct / n_seen,
                  n_correct,
                  n_seen,
              ))
def test(epoch):
    """Evaluate ``net`` on ``test_loader`` and print running metrics.

    Uses the module-level ``net``, ``test_loader`` and ``criterion``.
    After every test batch it prints the batch index, the number of test
    batches, the running mean loss and the running accuracy.

    Args:
        epoch: current epoch number. It is not used here; it is kept so
            ``train`` and ``test`` share a call signature.

    Returns:
        Accuracy over the whole test set in percent, or 0.0 if the loader
        yielded no samples.
    """
    # eval() switches layers such as Dropout/BatchNorm to inference behaviour.
    # It is harmless for an LSTM + Linear model, and it keeps evaluation
    # correct if such layers are added later.
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    # Evaluation needs no autograd graph; no_grad() saves memory and time.
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to('cpu'), targets.to('cpu')
            # (batch, 1, 28, 28) -> (batch, 28, 28): each image row is one time step.
            outputs = net(torch.squeeze(inputs, 1))
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            # Progress is reported against the number of TEST batches.
            print(batch_idx, len(test_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (test_loss / (batch_idx + 1), 100 * correct / total, correct, total))
    return 100 * correct / total if total else 0.0


# Train for 200 epochs, evaluating on the test set after each one.
# The __main__ guard keeps `import`-ing this module (e.g. to reuse Rnn_LSTM)
# from starting a training run; running the file as a script is unchanged.
if __name__ == '__main__':
    for epoch in range(200):
        train(epoch)
        test(epoch)

 

# 你可能感兴趣的：深度学习, lstm, 监督学习, 自然语言处理, pytorch