【Pytorch学习笔记】MNIST数据集的训练及简单应用(二)

上篇做了简单的介绍以及数据集的训练:请点这里

下篇将借助OpenCV试着实际使用一下我们训练的模型

首先:对于上篇描述的训练好的模型model,将其保存下来,使用时再加载出来

#训练模型
for epoch in range(10):    #训练10轮
    train(epoch)           #执行训练
    test()                 #每轮训练完都测试一下正确率

#保存模型
path = "C:/Users/yas/Desktop/pytorch/MNIST/model/model1.pth"
torch.save(model, path)

#加载模型
model = torch.load("C:/Users/yas/Desktop/pytorch/MNIST/model/model1.pth")

我们想要实现的是,通过摄像头捕获图片,实时识别出数字。大致有几个步骤:

1.调用摄像头,拿出每一帧

2.将每一帧输入进模型,得到输出的预测结果

3.将摄像头图像和预测值实时显示出来

这部分代码都包含在一个循环里面:

这其中包括了:获取每一帧图像,转换为灰度图像,反向二值化,改变其大小为28*28,最后还要转换为张量才能送到模型中去。(由于数据集中数字是黑底白字,我们日常手写是白底黑字,所以需要反向二值化操作)

然后是使用模型进行预测。

最后输出预测结果即可。

#1.调用摄像头,拿出每一帧   

    cap = cv2.VideoCapture(0)    #定义视频来源为摄像头
    while 1:
        ret, frame = cap.read()  # 摄像头读取,ret为是否读取成功,frame为视频的每一帧图像
        #展示原图像
        frame = cv2.resize(frame, (300, 200))
        cv2.imshow("source", frame)
        #展示灰度,二值化后图像
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)    #灰度化
        res, frame = cv2.threshold(frame, 90, 255, cv2.THRESH_BINARY_INV)   #反向二值化
        cv2.imshow("gray", frame)
        #展示输入模型的图像(方形)
        frame = cv2.resize(frame, (140, 140))
        cv2.imshow("28*28", frame)
        cv2.waitKey(100)       #延时,控制帧率

        frame = cv2.resize(frame, (28, 28))    #转换成28*28大小

        #两次升维,使其能送入模型
        testimg = torch.unsqueeze(testimg, dim=0)
        testimg = torch.unsqueeze(testimg, dim=0)
        #转换为浮点数
        testimg = testimg.to(torch.float32)

#2.将每一帧输入进模型,得到输出的预测结果
        #加载模型
        model = torch.load("C:/Users/yas/Desktop/pytorch/MNIST/model/model1.pth")
        predimg = model(testimg)                           #进行预测
        _, pred = torch.max(predimg.data, dim=1)           #获得最大值

        print('the predict num is', int(pred.data[0]))     #输出结果

最后结果如图:

【Pytorch学习笔记】MNIST数据集的训练及简单应用(二)_第1张图片

 【Pytorch学习笔记】MNIST数据集的训练及简单应用(二)_第2张图片

 【Pytorch学习笔记】MNIST数据集的训练及简单应用(二)_第3张图片

 

总的来看,当送进模型的图像(左上角)周围全黑且数字在中央时,准确率较高,一旦周围有杂质或是数字不在中央或是大小不合适,准确率就十分低了。

以上就是我正在入门pytorch的一些学习过程,还请各位指出其中错误指出,谢谢阅读!

你可能感兴趣的:(python,深度学习,pytorch,opencv,计算机视觉)