《深度学习之Pytorch实战计算机视觉》:手写数字识别

import torchvision
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torch.autograd import Variable

transform = transforms.Compose([transforms.ToTensor(),
                                # transforms.Lambda(lambda x: x.repeat(3,1,1)),
                                transforms.Normalize([0.5], [0.5])])
data_train = datasets.MNIST(root="./data/",
                            transform=transform,
                            train=True,
                            download=True)
data_test = datasets.MNIST(root="./data/",
                           transform=transform,
                           train=False)
data_loader_train = torch.utils.data.DataLoader(dataset=data_train,
                                                batch_size=64,
                                                shuffle=True)
data_loader_test = torch.utils.data.DataLoader(dataset=data_test,
                                               batch_size=64,
                                               shuffle=True)
images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)

img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print([labels[i] for i in range(64)])
plt.imshow(img)

transpose(1,2,0)将通道数换到列的位置,行换到了通道的位置,具体可以看Pytorch里面的一些小细节
images,labels =next(iter(data_loader_train))表示这句可能有点没看懂

根据我的理解:主要就是读取data_loader_train中的数据,其中数据存在两种形式一个是图片形式,一个标签形式,并且两者不能进行位置的更换

torchvision.utils.make_grid 的参数就是每次的装载数据,其中装载数据都是4维度,从前往后分别为batch _size、channel 、height 和weight , 分别对应一个批次中的数据个数、每张图片的色彩通道数、每张图片的高度和宽度。在通过torchvision.utils.make_grid 之后,图片的维度变成了( channel,height,weight ),这个批次的图片全部被整合到了一起, 所以在这个维度中对应的值也和之前不一样了, 但是色彩通道数保持不变。

Matplotlib显示数据维度形式(height,weight,channel)

这部分主要就是用于显示MNIST训练集中的图像

你可能感兴趣的:(《深度学习之Pytorch实战计算机视觉》:手写数字识别)