(系列更新完毕)深度学习零基础使用 PyTorch 框架跑 MNIST 数据集的第四天:单例测试

1. Introduction

今天是尝试用 PyTorch 框架来跑 MNIST 手写数字数据集的第四天,主要学习导入模型并进行单例测试。本 blog 主要记录一个学习的路径以及学习资料的汇总。

注意:这是用 Python 2.7 版本写的代码

第一天(LeNet 网络的搭建):https://blog.csdn.net/qq_36627158/article/details/108098147

第二天(加载 MNIST 数据集):https://blog.csdn.net/qq_36627158/article/details/108119048

第三天(训练模型):https://blog.csdn.net/qq_36627158/article/details/108163693

第四天(单例测试):https://blog.csdn.net/qq_36627158/article/details/108183655

 

 

 

 

2. Code(mnist_classify.py)

感谢 凯神 提供的代码与耐心指导!

from torchvision import transforms
from PIL import Image, ImageOps
from mnist_train import *


classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')
model = Net()


def load_checkpoint(checkpoint_path, model):
    state = torch.load(checkpoint_path)
    model.load_state_dict(state['model'])


if __name__ == '__main__':
    load_checkpoint(
        'module/pytorch-mnist-batch-128-1407.pth',
        model
    )

    model = model.to(device)
    model.eval()

    img = Image.open("/home/ubuntu/Downloads/C6/3.jpg")
    img = ImageOps.invert(img)

    # rgb -> single channel image
    if len(img.split()) > 1:
        img = img.split()[0]

    plt.figure()
    plt.imshow(img)
    plt.show()

    trans = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])
    img = trans(img)

    img = img.to(device)

    img = img.unsqueeze(0)

    output = model(img)
    prob = F.softmax(output, dim=1)

    max_value, max_index = torch.max(prob, 1)

    pred_class = classes[max_index.item()]
    print 'predicted class is', pred_class, ', probability is', round(max_value.item(), 6) * 100

 

 

 

3. Details

1、im.split()

r, g, b=im.split()   该函数用来将RGB图片分割成三个通道的图片

Python-Image 基本的图像处理操作

 

2、torch.unsqueeze()

为 Torch Tensor 添加维度

https://blog.csdn.net/xiexu911/article/details/80820028

你可能感兴趣的:(Python,学习,人工智能,PyTorch,LeNet,MNIST,python)