The Most Complete MNIST Series (6): Implementing RNN and LSTM on MNIST with PyTorch

Table of Contents

  • 1. Theory
  • 2. Code Implementation
    • 2.1 nets.py
    • 2.2 Train.py
    • 2.3 Testing and Loss Curve

1. Theory
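A brief recap of the LSTM cell in standard notation: at each time step t, the forget gate f_t, input gate i_t, and output gate o_t control how the cell state c_t and hidden state h_t are updated:

$$
\begin{aligned}
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \\
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \\
\tilde{c}_t &= \tanh(W_c x_t + U_c h_{t-1} + b_c) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$

For MNIST, each 28x28 image is treated as a sequence of 28 rows, so x_t is the t-th row of the image and the output at the final time step is used for classification.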

2. Code Implementation

Code directory:

(Figure 1: project code directory)

2.1 nets.py

import torch.nn as nn
import torch

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        # Treat each 28x28 image as a sequence of 28 rows, each a 28-dim vector.
        self.lstm = nn.LSTM(
            input_size=28,      # features per time step (one image row)
            hidden_size=128,    # hidden units per direction
            num_layers=2,       # stacked LSTM layers
            batch_first=True,   # input shape: (batch, seq_len, input_size)
            bidirectional=True  # forward + backward pass over the rows
        )

        # Bidirectional output is 128 * 2 wide; map it to the 10 digit classes.
        self.out = nn.Linear(128 * 2, 10)

    def forward(self, x):
        # r_out: (batch, seq_len, hidden_size * 2); hidden/cell states are unused.
        r_out, (h_t, c_t) = self.lstm(x, None)
        # Classify from the output at the last time step.
        out = self.out(r_out[:, -1, :])
        return out
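With batch_first=True the LSTM expects input of shape (batch, seq_len, input_size), so each 28x28 image is read row by row as a 28-step sequence of 28-dimensional vectors, and the bidirectional output at the last time step (128 * 2 = 256 features) feeds the linear classifier. A quick shape sanity check (a hypothetical snippet, not part of the project files):

import torch
from nets import MyNet

net = MyNet()
dummy = torch.randn(4, 28, 28)   # fake batch: (batch=4, seq_len=28, input_size=28)
logits = net(dummy)
print(logits.shape)              # torch.Size([4, 10]), one score per digit class

The title also covers the plain RNN; a minimal variant (a sketch only, with the hypothetical name MyRNNNet) simply swaps nn.LSTM for nn.RNN, which returns (output, h_n) instead of (output, (h_t, c_t)):

import torch.nn as nn

class MyRNNNet(nn.Module):
    def __init__(self):
        super(MyRNNNet, self).__init__()
        self.rnn = nn.RNN(
            input_size=28,
            hidden_size=128,
            num_layers=2,
            batch_first=True,
            bidirectional=True
        )
        self.out = nn.Linear(128 * 2, 10)

    def forward(self, x):
        r_out, h_n = self.rnn(x, None)    # plain RNN has no cell state
        return self.out(r_out[:, -1, :])  # classify from the last time step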

2.2 Train.py

from torchvision.datasets import MNIST
import torch
import torch.nn as nn
from torch.utils import data
from torchvision import transforms
from nets import MyNet
import os
import matplotlib.pyplot as plt

class Trainer:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.loss_fn = nn.CrossEntropyLoss()
        self.net = MyNet().to(self.device)
        # Convert images to tensors and normalize with the MNIST mean/std.
        self.trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1304,), (0.3084,))
        ])
        self.opt = torch.optim.Adam(self.net.parameters())

    def train(self):
        NUM_EPOCHS = 10
        BATCH_SIZE = 100
        save_path = r"models/rnn.pth"
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        # Resume from a saved checkpoint if one exists.
        if os.path.exists(save_path):
            self.net.load_state_dict(torch.load(save_path))
        else:
            print("No Params")
        train_data = MNIST(root="./MNIST", train=True, download=False, transform=self.trans)

        train_loader = data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
        losses = []
        for epoch in range(NUM_EPOCHS):
            for i, (x, y) in enumerate(train_loader):
                # Drop the channel dim: (N, 1, 28, 28) -> (N, 28, 28),
                # i.e. 28 time steps of 28 features per image.
                x = x.reshape(-1, 28, 28)
                img = x.to(self.device)
                label = y.to(self.device)
                out = self.net(img)
                loss = self.loss_fn(out, label)

                if i % 50 == 0:
                    print("epoch:{}/{}, iteration:{}/{}, loss:{:.3f}".
                          format(epoch, NUM_EPOCHS, i, len(train_loader), loss.item()))
                    # Store plain floats (not CUDA tensors) so matplotlib can plot them.
                    losses.append(loss.item())
                    plt.clf()
                    plt.title("Loss")
                    plt.plot(losses)
                    plt.pause(0.01)
                    plt.savefig("loss.jpg")
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
            torch.save(self.net.state_dict(), save_path)

    def test(self):
        BATCH_SIZE = 100
        test_data = MNIST(root="./MNIST", train=False, download=False, transform=self.trans)
        test_loader = data.DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=True)
        save_path = "models/rnn.pth"
        self.net.load_state_dict(torch.load(save_path))
        self.net.eval()
        eval_loss = 0
        eval_acc = 0
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for x, y in test_loader:
                x = x.reshape(-1, 28, 28)
                img = x.to(self.device)
                label = y.to(self.device)
                out = self.net(img)
                loss = self.loss_fn(out, label)

                # Accumulate the summed loss and the number of correct predictions.
                eval_loss += loss.item() * label.size(0)
                argmax = torch.argmax(out, 1)
                num_acc = (argmax == label).sum()
                eval_acc += num_acc.item()
        # Print the predictions and labels of the last batch, then the averaged metrics.
        print(torch.argmax(out, 1))
        print(label)
        print('Test Loss: {:.3f}, Acc: {:.3f}%'
              .format(eval_loss / len(test_data),
                      eval_acc / len(test_data) * 100))
if __name__ == '__main__':
    t = Trainer()
    # t.train()  # run training first so that models/rnn.pth exists
    t.test()
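The Normalize((0.1304,), (0.3084,)) values used in the transform are approximately the global mean and standard deviation of the MNIST training pixels in [0, 1]. A rough sketch of how such values can be estimated, assuming the dataset has already been downloaded to ./MNIST:

import torch
from torchvision import transforms
from torchvision.datasets import MNIST

train_data = MNIST(root="./MNIST", train=True, download=False,
                   transform=transforms.ToTensor())
# Stack all 60000 1x28x28 images and compute one global mean and std.
imgs = torch.stack([img for img, _ in train_data])
print(imgs.mean().item(), imgs.std().item())   # roughly 0.13 and 0.31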

2.3 Testing and Loss Curve

tensor([8, 2, 0, 8, 0, 9, 0, 5, 7, 8, 1, 5, 6, 7, 2, 7, 1, 5, 0, 7, 8, 9, 5, 1,
        1, 6, 3, 1, 5, 4, 6, 4, 8, 7, 2, 4, 3, 9, 2, 2, 3, 9, 9, 3, 1, 4, 7, 7,
        9, 4, 5, 0, 7, 4, 1, 0, 4, 8, 4, 1, 8, 7, 4, 3, 0, 7, 5, 6, 1, 2, 2, 1,
        5, 4, 2, 2, 5, 9, 3, 6, 3, 8, 4, 2, 4, 0, 7, 9, 7, 9, 2, 0, 0, 2, 1, 8,
        8, 5, 6, 5], device='cuda:0')
tensor([8, 2, 0, 8, 0, 9, 0, 5, 7, 8, 1, 5, 6, 7, 2, 7, 1, 5, 0, 7, 8, 9, 5, 1,
        1, 6, 3, 1, 5, 4, 6, 4, 8, 7, 2, 4, 3, 9, 2, 2, 3, 9, 9, 3, 1, 4, 7, 7,
        9, 4, 5, 0, 7, 4, 1, 0, 4, 8, 4, 1, 8, 7, 4, 3, 0, 7, 5, 6, 1, 2, 2, 1,
        5, 4, 2, 2, 5, 9, 3, 6, 3, 8, 4, 2, 4, 0, 7, 9, 7, 9, 2, 0, 0, 2, 1, 8,
        8, 5, 6, 5], device='cuda:0')
Test Loss: 0.040, Acc: 98.840%

(Figure 2: training loss curve, saved as loss.jpg)
