[LeNet-5] Handwritten Digit Recognition on MNIST

Table of Contents

  • 1. The LeNet-5 Model
  • 2. Training the Model
  • 3. Output

1. The LeNet-5 Model

(Figures: LeNet-5 network architecture diagrams)

Model characteristics: each convolutional block consists of three parts: a convolution, a nonlinear activation (Tanh), and average pooling.

import torch
import torch.nn as nn


class LeNet5(nn.Module):
    """Build the network with nn.Sequential, which groups layers together into one module."""
    def __init__(self, in_channel, output):
        super(LeNet5, self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(in_channels=in_channel, out_channels=6, kernel_size=5, stride=1, padding=2),   # (6, 28, 28)
                                    nn.Tanh(),
                                    nn.AvgPool2d(kernel_size=2, stride=2, padding=0))   # (6, 14, 14)

        self.layer2 = nn.Sequential(nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0),  # (16, 10, 10)
                                    nn.Tanh(),
                                    nn.AvgPool2d(kernel_size=2, stride=2, padding=0))   # (16, 5, 5)

        self.layer3 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5)  # (120, 1, 1)

        self.layer4 = nn.Sequential(nn.Linear(in_features=120, out_features=84),
                                    nn.Tanh(),
                                    nn.Linear(in_features=84, out_features=output))

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = torch.flatten(input=x, start_dim=1)
        x = self.layer4(x)
        return x
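
As a quick sanity check of the shape annotations in the comments above, a dummy MNIST-sized input can be pushed through the model. This is a minimal sketch that assumes the LeNet5 class defined above and a working PyTorch install; the variable names are only illustrative.

# Sanity-check the intermediate shapes with one fake 1x28x28 grayscale image
net = LeNet5(in_channel=1, output=10)
x = torch.randn(1, 1, 28, 28)
print(net.layer1(x).shape)              # torch.Size([1, 6, 14, 14])
print(net.layer2(net.layer1(x)).shape)  # torch.Size([1, 16, 5, 5])
print(net(x).shape)                     # torch.Size([1, 10])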

2. Training the Model

import numpy as np
import torch
import torch.nn as nn
from torchvision.datasets import mnist
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.optim as optim
import matplotlib.pyplot as plt

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


train_batch_size = 12
test_batch_size = 48
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])  # ToTensor scales pixels to [0, 1]; Normalize then maps them to [-1, 1]
# Download & load the data
train_set = mnist.MNIST("./mnist_data", train=True, download=True, transform=transform)
test_set = mnist.MNIST("./mnist_data", train=False, transform=transform)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=train_batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=test_batch_size, shuffle=False)

# # Preview a few sample images
# examples = enumerate(test_loader)
# batch_idex, (example_data, example_label) = next(examples)
# sample_set = np.array(example_data)
#
# for i in range(6):
#     plt.subplot(2, 3, i + 1)
#     plt.imshow(sample_set[i][0])
#     plt.title("Ground Truth: {}".format(example_label[i]))
# plt.show()


class LeNet5(nn.Module):
    """ 使用sequential构建网络,Sequential()函数的功能是将网络的层组合到一起 """
    def __init__(self, in_channel, output):
        super(LeNet5, self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(in_channels=in_channel, out_channels=6, kernel_size=5, stride=1, padding=2),   # (6, 28, 28)
                                    nn.Tanh(),
                                    nn.AvgPool2d(kernel_size=2, stride=2, padding=0))   # (6, 14, 14)

        self.layer2 = nn.Sequential(nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0),  # (16, 10, 10)
                                    nn.Tanh(),
                                    nn.AvgPool2d(kernel_size=2, stride=2, padding=0))   # (16, 5, 5)

        self.layer3 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5)  # (120, 1, 1)

        self.layer4 = nn.Sequential(nn.Linear(in_features=120, out_features=84),
                                    nn.Tanh(),
                                    nn.Linear(in_features=84, out_features=output))

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = torch.flatten(input=x, start_dim=1)
        x = self.layer4(x)
        return x

model = LeNet5(1, 10)
model.to(device)

lr = 0.01
num_epoches = 20
momentum = 0.8

criterion = nn.CrossEntropyLoss()   # expects raw logits; log-softmax is applied internally
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)


eval_losses = []
eval_acces = []

for epoch in range(num_epoches):

    # Shrink the learning rate by 10x every 5 epochs (note: the condition is also true at epoch 0)
    if epoch % 5 == 0:
        optimizer.param_groups[0]['lr'] *= 0.1

    model.train()
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        predict = model(imgs)
        loss = criterion(predict, labels)

        # back propagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


    # evaluate on the test set after each epoch
    eval_loss = 0
    eval_acc = 0
    model.eval()
    with torch.no_grad():   # gradients are not needed during evaluation
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            predict = model(imgs)
            loss = criterion(predict, labels)

            # record loss
            eval_loss += loss.item()

            # record accuracy
            result = torch.argmax(predict, dim=1)
            acc_num = (result == labels).sum().item()
            acc_rate = acc_num / imgs.shape[0]
            eval_acc += acc_rate

    eval_losses.append(eval_loss / len(test_loader))
    eval_acces.append(eval_acc / len(test_loader))

    print('epoch: {}'.format(epoch))
    print('loss: {}'.format(eval_loss / len(test_loader)))
    print('accurate rate: {}'.format(eval_acc / len(test_loader)))
    print('\n')

plt.title('evaluation loss')
plt.plot(np.arange(len(eval_losses)), eval_losses)
plt.show()
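
The accuracy history in eval_acces is collected above but never visualized; a minimal follow-up sketch, assuming the lists produced by the training loop above, plots it the same way as the loss curve.

plt.title('evaluation accuracy')
plt.plot(np.arange(len(eval_acces)), eval_acces)
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.show()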

3. Output

epoch: 0
loss: 0.20932436712157498
accurate rate: 0.9417862838915463

epoch: 1
loss: 0.1124003769263946
accurate rate: 0.9681020733652314

epoch: 2
loss: 0.0809573416740736
accurate rate: 0.9753787878787872

epoch: 3
loss: 0.07089491755452061
accurate rate: 0.9779704944178623

epoch: 4
loss: 0.05831286043338656
accurate rate: 0.9821570972886757

epoch: 5
loss: 0.05560500273351785
accurate rate: 0.9828548644338115

epoch: 6
loss: 0.0542455422597309
accurate rate: 0.9835526315789472

epoch: 7
loss: 0.05367041283908732
accurate rate: 0.9838516746411479

epoch: 8
loss: 0.05298826666370605
accurate rate: 0.9838516746411481

epoch: 9
loss: 0.05252152112530963
accurate rate: 0.9836523125996807

epoch: 10
loss: 0.05247020455629846
accurate rate: 0.9836523125996808

epoch: 11
loss: 0.05242454297127621
accurate rate: 0.9837519936204145

epoch: 12
loss: 0.05237526405083559
accurate rate: 0.9838516746411481

epoch: 13
loss: 0.05233189105290171
accurate rate: 0.9839513556618819

epoch: 14
loss: 0.05222674906053291
accurate rate: 0.9837519936204145

epoch: 15
loss: 0.052228276117072044
accurate rate: 0.9837519936204145

epoch: 16
loss: 0.05222897543727852
accurate rate: 0.9837519936204145

epoch: 17
loss: 0.05222897782574216
accurate rate: 0.9838516746411481

epoch: 18
loss: 0.05222847037079731
accurate rate: 0.9838516746411481

epoch: 19
loss: 0.05222745426054866
accurate rate: 0.9838516746411481

(Figure: evaluation loss curve over epochs)
