pytorch读取mnist数据集并进行展示

import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.nn import functional
from torchvision import datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_data_loader = DataLoader(dataset=train_data, shuffle=True, batch_size=64)     #使用加载器加载,可以降低内存的消耗
test_data = datasets.MNIST(root='./data', train=False, transform=transform)
test_data_loader = DataLoader(dataset=test_data, shuffle=True, batch_size=64)

examples = enumerate(train_data_loader)
batch_idx, (example_data, example_label) = next(examples)


#fig = plt.figure()

for i in range(8):
    plt.subplot(4, 2, i + 1)
    plt.tight_layout()  #自动调整子图参数,使之填充整个图像区域
	
	#这里的example_data[i][0]后面的0表示的可能是色彩级数,如RGB有3个,灰度图就只有一个
    plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
    plt.title("Ground Truth:{}".format(example_label[i]))
    plt.xticks([])
    plt.yticks([])
plt.show()

在神经网络中,数据一般会被分为两个部分,一个部分用于训练,另一部分用于测试,查看训练的网络的效果如何。
那么第一步就是数据下载并进行处理,datasets里面包含有多个数据集,Mnist就是其中的一个,其有60000张28*28的灰度图像,设置download = “true”,可以将数据下载到本地来进行操作

  1. transforms.Compose可以将一些转换函数组合到一起
  2. Normalize([0.5], [0.5])对张量进行归一化,这里的两个0.5分别表示均值和方差。因图像是灰色的只有一个通道,如果有多个通道,如RGB图像,则需要有多个通道数字,如3个通道,应该是Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])
  3. download的参数控制是否需要下载,如果该.data目录下已经有MNIST,可选择False
  4. 用DataLoader得到生成器,这可以节省内存
  5. torchvision是torch的图像工具包

在得到DataLoader迭代器后,我们可以使用next进行迭代取出数据,在DataLoader里面有一个batch_size参数用于控制每一次迭代我们从里面选择出多少数据,如这里的batch_size我们设置为64,则其一次迭代取出的就是64个数据

#这里的examples包含64个手写体数字以及其他的一些参数
examples = enumerate(train_data_loader)
#next()函数获得迭代的下一个值,返回的有批号,样本数据以及标签,很明显,enumerate()返回的应该是下标0,我么要移动到下标1上才能读取数据
#注意:样本数据(Tensor类型)和标签(Tensor类型)放置在一个列表中,需要加上括号
batch_idx, (example_data, example_label) = next(examples)

我们打开调试,具体看看每个里面都是什么:

  1. batch_idx
  2. example_data
    pytorch读取mnist数据集并进行展示_第1张图片

example_label:
pytorch读取mnist数据集并进行展示_第2张图片
可以看到,我们的example_data和example_label里面都包含了很多的属性

展示我们的手写体数据:
pytorch读取mnist数据集并进行展示_第3张图片

你可能感兴趣的:(机器学习)