MNIST数据集可能是计算机视觉所接触的第一个图片数据集。而 Fashion MNIST 是在遵循 MNIST 的格式和大小的基础上,提升了一定的难度,在比较算法的性能时可以有更好的区分度。
Fashion MNIST 数据集包含 60000 张图片的训练集和 10000 张图片的测试集。图片的大小为 28×28,共784个像素。像素的灰度值介于0~255之间的整数。
数据集分为10个类别,分别是:
torchvision.datasets
加载数据集import torch
import torchvision
import Image
# 使用 torchvision.datasets.FashionMNIST 下载数据
# data_path: 数据集的保存路径
# train: True下载训练集,False下载测试集
# transform: 图片预处理
# download: 是否从网络上下载数据
>>> data_path = "保存路径"
>>> train_data = torchvision.datasets.FashionMNIST(data_path, train=True, transform=None, download=True)
# 查看数据内容
>>> train_data.data.shape
torch.Size([60000, 28, 28])
>>> train_data.targets.shape
torch.Size([60000])
>>> train_data.classes
['T-shirt/top',
'Trouser',
'Pullover',
'Dress',
'Coat',
'Sandal',
'Shirt',
'Sneaker',
'Bag',
'Ankle boot']
# 查看数据集的数据类型
>>> train_data
Dataset FashionMNIST
Number of datapoints: 60000
Root location: ..\pytorch\data
Split: Train
# Dataset FashionMNIST对象是 torch.utils.data.dataset.Dataset 的子集
>>> train_data.__class__.__mro__
(torchvision.datasets.mnist.FashionMNIST,
torchvision.datasets.mnist.MNIST,
torchvision.datasets.vision.VisionDataset,
torch.utils.data.dataset.Dataset,
typing.Generic,
object)
如果已经将数据集下载到了本地,可以直接解析文件来导入数据集。
数据集包括以下四个文件:
train-images-idx3-ubyte
train-labels-idx1-ubyte
t10k-images-idx3-ubyte
t10k-labels-idx1-ubyte
这是四个二进制文件,其中的idx3表示有三个维度,idx1表示有一个维度。 这里使用常用的二进制解析库 struct 来解析文件。
import numpy as np
import struct
# 文件路径
data_path = r'路径'
file_names = ['t10k-images-idx3-ubyte',
't10k-labels-idx1-ubyte',
'train-images-idx3-ubyte',
'train-labels-idx1-ubyte']
def decode_idx3_ubyte(file):
"""
解析数据文件
"""
# 读取二进制数据
with open(file, 'rb') as fp:
bin_data = fp.read()
# 解析文件中的头信息
# 从文件头部依次读取四个32位,分别为:
# magic,numImgs, numRows, numCols
# 偏置
offset = 0
# 读取格式: 大端
fmt_header = '>iiii'
magic, numImgs, numRows, numCols = struct.unpack_from(fmt_header, bin_data, offset)
print(magic,numImgs,numRows,numCols)
# 解析图片数据
# 偏置掉头文件信息
offset = struct.calcsize(fmt_header)
# 读取格式
fmt_image = '>'+str(numImgs*numRows*numCols)+'B'
data = torch.tensor(struct.unpack_from(fmt_image, bin_data, offset)).reshape(numImgs, numRows, numCols)
return data
def decode_idx1_ubyte(file):
"""
解析标签文件
"""
# 读取二进制数据
with open(file, 'rb') as fp:
bin_data = fp.read()
# 解析文件中的头信息
# 从文件头部依次读取两个个32位,分别为:
# magic,numImgs
# 偏置
offset = 0
# 读取格式: 大端
fmt_header = '>ii'
magic, numImgs = struct.unpack_from(fmt_header, bin_data, offset)
print(magic,numImgs)
# 解析图片数据
# 偏置掉头文件信息
offset = struct.calcsize(fmt_header)
# 读取格式
fmt_image = '>'+str(numImgs)+'B'
data = torch.tensor(struct.unpack_from(fmt_image, bin_data, offset))
return data
train_set = (decode_idx3_ubyte(os.path.join(data_path, file_names[0])),
decode_idx1_ubyte(os.path.join(data_path, file_names[1])))
test_set = (decode_idx3_ubyte(os.path.join(data_path, file_names[2])),
decode_idx1_ubyte(os.path.join(data_path, file_names[3])))
运行结果:
2051 10000 28 28
2049 10000
2051 60000 28 28
2049 60000
使用 pytorch
的 DataLoader
对象分批读取数据。
# 将data和label张量封装为数据类,可通过第一个维度来索引每一个样本。
train_data = torch.utils.data.TensorDataset(*train_set)
test_data = torch.utils.data.TensorDataset(*test_set)
# 创建数据加载器,小批次读取数据
batch_size = 5050
train_Loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
# 分批读取
for X,y in train_Loader:
print(X.shape, y.shape)
运行结果:
torch.Size([5050, 28, 28]) torch.Size([5050])
torch.Size([5050, 28, 28]) torch.Size([5050])
torch.Size([5050, 28, 28]) torch.Size([5050])
torch.Size([5050, 28, 28]) torch.Size([5050])
torch.Size([5050, 28, 28]) torch.Size([5050])
torch.Size([5050, 28, 28]) torch.Size([5050])
torch.Size([5050, 28, 28]) torch.Size([5050])
torch.Size([5050, 28, 28]) torch.Size([5050])
torch.Size([5050, 28, 28]) torch.Size([5050])
torch.Size([5050, 28, 28]) torch.Size([5050])
torch.Size([5050, 28, 28]) torch.Size([5050])
torch.Size([4450, 28, 28]) torch.Size([4450])