%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
d2l.use_svg_display()
SVG是一种无损格式 – 意味着它在压缩时不会丢失任何数据,可以呈现无限数量的颜色。 d2l.use_svg_display() 意思是使用svg来显示图片,这样清晰度高一些。
# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0到1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
root="../data", train=False, transform=trans, download=True)
trans = transforms.ToTensor() 暂且按下不表。看名字大概就是把数据转换成tensor。
torchvision.datasets.FashionMNIST是PyTorch自带的读取MNIST的库,其关键字如下:
root (string) – Root directory of dataset where FashionMNIST/raw/train-images-idx3-ubyte
and FashionMNIST/raw/t10k-images-idx3-ubyte
exist.
用于储存训练数据(以上两个文件)的目录
这里如果直接从资源管理器复制目录会出现以下错误:
SyntaxError: (unicode error) 'unicodeescape' codec can't decode bytes in position 2-3: truncated \xXX escape
百度可知,是因为字符串里面的\右斜杠被识别为转义字符
解决方法有两种:
1.在字符串前,增加r,保持字符串的原始含义
root=r"../data"
2.把右斜杠\改成左斜杠/ 即可
train (bool, optional) – If True, creates dataset from train-images-idx3-ubyte
, otherwise from t10k-images-idx3-ubyte
.
这里是用布尔值指定是训练集(train-images-idx3-ubyte
)还是测试集(t10k-images-idx3-ubyte
)
download (bool, optional) – If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
如果实现准备好了数据集,可以把这个布尔值指定为False,否则就自动下载一份数据集到目录
这里本地下载文件非常麻烦,而且可直接下载,无需特别手段,所以可以指定好目录直接下载。
若非要下载,首先文件会存储在...\data\FashionMNIST\raw目录下,而且必须有八个文件存在(四个数据包和他们的解压文件)才能通过检测。
transform (callable, optional) – A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop
与上文的
trans = transforms.ToTensor() 对上了
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
print(len(mnist_train), len(mnist_test))
print(mnist_train[0][0].shape)
这里下面的输出为
torch.Size([1, 28, 28])
这表明这第一张图片是一个黑白图片,channel数为1,长宽都是28px
def get_fashion_mnist_labels(labels): #@save
"""返回Fashion-MNIST数据集的文本标签"""
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save
"""绘制图像列表"""
figsize = (num_cols * scale, num_rows * scale)
_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
axes = axes.flatten()
for i, (ax, img) in enumerate(zip(axes, imgs)):
if torch.is_tensor(img):
# 图片张量
ax.imshow(img.numpy())
else:
# PIL图片
ax.imshow(img)
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
if titles:
ax.set_title(titles[i])
return axes
get_fashion_mnist_labels(labels)函数
可以输入labels为一个数组,其中的每一项都是0-9的数字,对应标签中的10项,然后输出一个数组,其中每一项都是一个表示类别的字符串
show_images(imgs, num_rows, num_cols, titles=None, scale=1.5)函数
这个函数是用来输出个体对应的图片。
img是输入的图片;
num_rows, num_cols是输出图片时,想把图片输出的行列数,如num_rows=2, num_cols=9就是输出2行9列图片;
titles是显示在图片上方的标签,用于把数据集的标签显示出来;
scale是缩放比例,可以把输入的图片(28*28太小了,放大一些会比较好)进行放缩在输出。
函数内部的流程:
1.先把输出图像的尺寸定下来,也就是num_cols * scale, num_rows * scale,输出给figsize
2.创建画图用的画布。这里注意_, axes 是两个变量!但是_变量是作为占位,不被使用。
这里使用了subplots这个命令,用法:
fig, ax = plt.subplots()这样使用,其中要把图和锚(规定图像的坐标轴大小、图片数量等)一起传输进去。
.flatten是把二维数组变成一维数组,不清楚为什么要这么干
询问了一下Bing,这是为了遍历数组中每个图的坐标。
在下面的这个循环中,i是每个子图的索引,后面的(ax,img)则是每个子图的内容。为了实现这个功能,使用了zip和enumerate两个函数,他们的用法如下:
zip(a,b)可以把两个数组打包为一个数组,例如:a=[a,b,c,d],b=[1,2,3,4],那么list(zip(a,b))=[(a,1),(b,2),(c,3),(d,4)]
enumerate是为了一次遍历两个元素,可以把i作为索引,而后面的(ax,img)是每项的内容。
my_list = ['a', 'b', 'c']
for i, value in enumerate(my_list):
print(f'Index: {i}, Value: {value}')
输出结果为:
Index: 0, Value: a
Index: 1, Value: b
Index: 2, Value: c
然后,就把图片展示出来:
在接下来的循环中,判断图片img是否是张量,如果是,则将其转换为 NumPy 数组并使用 imshow
函数绘制;否则,直接使用 imshow
函数绘制。
接下来分别获取了axes的x、y轴,并将其设置为不可见。
最后函数返回axes的值
简单记录一下自己对DataLoader的理解:
DataLoader是torch.utils.data模块的一个类,它的用法是:
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=
其中比较重要的:
dataset:就是输入的数据集
batch_size:顾名思义,一次读取的数据个数
shuffle:(数据类型 bool)是否打乱数据?设置为True就会打乱
timer = d2l.Timer()
for X, y in train_iter:
continue
f'{timer.stop():.2f} sec'
这里d2l.Timer只是一个计时器,可以使用timer.start()和timer.stop()来记录这两个命令时间代码运行的时间。