Pytorch入门--详解Mnist手写字识别

1 什么是Mnist?

        Mnist是计算机视觉领域中最为基础的一个数据集。

        MNIST数据集(Mixed National Institute of Standards and Technology database)是美国国家标准与技术研究院收集整理的大型手写数字数据集,包含了60,000个样本的训练集以及10,000个样本的测试集。Mnist中所有样本都会将原本28*28的灰度图转换为长度为784的一维向量作为输入,其中每个元素分别对应了灰度图中的灰度值。Mnist使用一个长度为10的向量作为该样本所对应的标签,其中向量索引值对应了该样本以该索引为结果的预测概率。
 

2、代码实现

需导入的python库

import torch
import scipy.misc
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torch import optim


import torch.nn as nn
import torch.nn.functional as F

构建模型

# 构建模型(简单的卷积神经网络)
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size =5, padding = 2) # 卷积
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Linear(in_feactures(输入的二维张量大小), out_feactures)
        self.fc1   = nn.Linear(16*5*5, 120) # 全连接
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, 10) # 最后输出10个类

    def forward(self, x):
        # 激活函数
        out = F.relu(self.conv1(x))
        # max_pool2d(input, kernel_size(卷积核), stride(卷积核步长)=None, padding=0, dilation=1, ceil_mode(空间输入形状)=False, return_indices=False)
        out = F.max_pool2d(out, kernel_size = 2) # 池化

        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)

        # 将多维的的数据平铺为一维
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

训练集

def train():
    # 学习率0.001
    learning_rate = 1e-3
    # 单次大小
    batch_size = 100
    # 总的循环
    epoches = 50
    lenet = LeNet()

    # 1、数据集准备
    # 这个函数包括了两个操作:transforms.ToTensor()将图片转换为张量,transforms.Normalize()将图片进行归一化处理
    trans_img = transforms.Compose([transforms.ToTensor()])
    # path = './data/'数据集下载后保存的目录,下载训练集
    trainset = MNIST('./data', train=True, transform=trans_img, download=True)
    # 构建数据集的DataLoader,
    # Pytorch自提供了DataLoader的方法来进行训练,该方法自动将数据集打包成为迭代器,能够让我们很方便地进行后续的训练处理
    # 迭代器(iterable)是一个超级接口! 是可以遍历集合的对象,
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=10)

    # 2、构建迭代器与损失函数
    criterian = nn.CrossEntropyLoss(reduction='sum')  # loss(损失函数)
    optimizer = optim.SGD(lenet.parameters(), lr=learning_rate)  # optimizer(迭代器)

    # 如果网络能在GPU中训练,就使用GPU;否则使用CPU进行训练
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    #lenet.to("cpu")

    # 3、训练
    for i in range(epoches):
        running_loss = 0.
        running_acc = 0.
        for (img, label) in trainloader:  # 将图像和标签传输进device中
            optimizer.zero_grad()  # 求梯度之前对梯度清零以防梯度累加
            output=lenet(img)  # 对模型进行前向推理
            loss=criterian(output,label)  # 计算本轮推理的Loss值
            loss.backward()    # loss反传存到相应的变量结构当中
            optimizer.step()   # 使用计算好的梯度对参数进行更新
            running_loss+=loss.item()
            #print(output)
            _,predict=torch.max(output,1)  # 计算本轮推理的准确率
            correct_num=(predict==label).sum()
            running_acc+=correct_num.item()

        running_loss/=len(trainset)
        running_acc/=len(trainset)
        print("[%d/%d] Loss: %.5f, Acc: %.2f" % (i + 1, epoches, running_loss,100 * running_acc))

    return lenet

 测试集

def test(lenet):
    batch_size = 100
    trans_img = transforms.Compose([transforms.ToTensor()])
    testset = MNIST('./data', train=False, transform=trans_img, download=True)
    testloader = DataLoader(testset, batch_size, shuffle=False, num_workers=10)
    running_acc = 0.
    for (img, label) in testloader:
        output = lenet(img)
        _, predict = torch.max(output, 1)
        correct_num = (predict == label).sum()
        running_acc += correct_num.item()
    running_acc /= len(testset)
    return running_acc

主函数

if __name__ == '__main__':
    lenet = train()
    torch.save(lenet, 'lenet.pkl') # save model

    lenet = torch.load('lenet.pkl') # load model
    test_acc = test(lenet)
    print("Test Accuracy:Loss: %.2f" % test_acc)

结果:

Pytorch入门--详解Mnist手写字识别_第1张图片

 继上面对minst手写数字集进行训练和测试的完成,现将黑底白字的0~9的数字图片,进行识别

Pytorch入门--详解Mnist手写字识别_第2张图片Pytorch入门--详解Mnist手写字识别_第3张图片Pytorch入门--详解Mnist手写字识别_第4张图片Pytorch入门--详解Mnist手写字识别_第5张图片Pytorch入门--详解Mnist手写字识别_第6张图片Pytorch入门--详解Mnist手写字识别_第7张图片Pytorch入门--详解Mnist手写字识别_第8张图片Pytorch入门--详解Mnist手写字识别_第9张图片Pytorch入门--详解Mnist手写字识别_第10张图片Pytorch入门--详解Mnist手写字识别_第11张图片

识别函数代码如下

def practice(img_path):
    img = Image.open(img_path)
    img = img.convert('L')
    prac_img = transforms.Compose([transforms.Resize((28, 28)),transforms.ToTensor()])
    pracset = MNIST('./data', train=True, transform=prac_img, download=True)
    img = prac_img(img)
    img = torch.reshape(img, (1, 1, 28, 28))
    lenet = torch.load('lenet.pkl')  # load model
    output = lenet(img)
    output = output.argmax(1)
    dict_target = pracset.class_to_idx
    dict_target = [indx for indx, vale in dict_target.items()]  # 获得标签字典
    print('识别类型为{}'.format(dict_target[output]))

主函数调用(图片路径按自身修改)

if __name__ == '__main__':
    practice('0.jpg')
    practice('1.jpg')
    practice('2.jpg')
    practice('3.jpg')
    practice('4.jpg')
    practice('5.jpg')
    practice('6.jpg')
    practice('7.jpg')
    practice('8.jpg')
    practice('9.jpg')

识别结果如下所示:

Pytorch入门--详解Mnist手写字识别_第12张图片

你可能感兴趣的:(pytorch,深度学习,卷积神经网络)