pytorch读取MNIST数据集并显示

pytorch读取MNIST数据集并显示

#直接下载数据集并读取

1 代码:

import torch
import torchvision
import matplotlib.pyplot as plt #用于显示图片

import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

#忽略警告
import warnings
warnings.filterwarnings('ignore')

#选择运行设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#下载数据集    28*28=784
dataset_train = torchvision.datasets.MNIST(root='./dataset_method_1', train=True, transform=torchvision.transforms.ToTensor(), download=True)
dataset_test = torchvision.datasets.MNIST(root='./dataset_method_1', train=False, transform=torchvision.transforms.ToTensor(), download=False)

#将数据集按批量大小加载到数据集中
data_loader_train = torch.utils.data.DataLoader(dataset=dataset_train, batch_size=100, shuffle=True)  #600*100*([[28*28],x])
data_loader_test = torch.utils.data.DataLoader(dataset=dataset_test, batch_size=100, shuffle=False)

#for epoch in range(5):  #一共五个周期,其中一个周期(len(dataset_train)=60000)/(batch_size=100)=(len(dataset_train)=600)个批量
for i, (images, labels) in enumerate(data_loader_train):

    #print(i, images[0].shape, labels[0].shape)
    '''
        每一个周期,共600个批次(i=0~599);
        data_loader_train包含600个批次,包括整个训练集;
        每一批次一共100张图片,对应100个标签, len(images[0])=1;
        images包含一个批次的100张图片(image[0].shape=torch.Size([1,28,28])),labels包含一个批次的100个标签,标签范围为0~9
    '''

    #每100个批量绘制绘制最后一个批量的所有图片
    if (i + 1) % 100 == 0:
        print('batch_number [{}/{}]'.format(i + 1, len(data_loader_train)))
        for j in range(len(images)):
            image = images[j].resize(28, 28) #将(1,28,28)->(28,28)
            plt.imshow(image)  # 显示图片,接受tensors, numpy arrays, numbers, dicts or lists
            plt.axis('off')  # 不显示坐标轴
            plt.title("$The {} picture in {} batch, label={}$".format(j + 1, i + 1, labels[j]))
            plt.show()

2 文件结构

数据集在dataset_method_1中,代码位于method_1.py中
pytorch读取MNIST数据集并显示_第1张图片

3 运行结果(部分)

pytorch读取MNIST数据集并显示_第2张图片
pytorch读取MNIST数据集并显示_第3张图片

你可能感兴趣的:(pytorch,深度学习,pytorch,MNIST,数据集读取与显示)