图像分类数据集

图像分类数据集

MNIST数据集是图像分类中广泛使用的数据集之一,但作为基准数据集过于简单。我们将使用类似但更复杂的Fashion-MNIST数据集

%matplotlib inline
import torch
import torchvision
from torch.utils import data
#transforms 对数据操作的包
from torchvision import transforms
#存在d2l
from d2l import torch as d2l

d2l.use_svg_display()

通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中

# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式
# 并处以255使得所有像素的数值均在0到1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="./data/FashionMNIST", train=True,transform=trans,download = False)
mnist_test = torchvision.datasets.FashionMNIST(root="./data/FashionMNIST", train=False,transform=trans,download = False)
len(mnist_train), len(mnist_test)
(60000, 10000)
mnist_train[0][0].shape
torch.Size([1, 28, 28])
def get_fashion_mnist_labels(labels):
    """返回Fashion-MNIST数据集的文本标签。"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
def show_fashion_mnist(images, labels):
    d2l.use_svg_display()
    #这里的_表示我们忽略(不使用)的变量
    _, figs = plt.subplots(1, len(images),figsize=(12,12))
    for f,img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axus.get_yaxis().set_visible(False)
    plt.show()

def show_image(imgs, num_row=1, num_col=1, titles=None, scale=1.5):
    # 设置图片大小
    figsize = (num_col * scale,  num_row * scale)
    
    # 这里的 _ 表示忽略不使用的变量、即fig
    _, axes = d2l.plt.subplots(num_row, num_col, figsize=figsize)
    
    for i, (img, label)  in  enumerate(zip(imgs, titles)):
        # 计算图片的位置、需要用到整除和除余
        xloc,   yloc   =   i // num_col,   i % num_col
        # 判断传入的图片是否为张量
        if torch.is_tensor(img):
            axes[xloc, yloc].imshow(img.reshape((28, 28)).numpy())
        else:
            axes[xloc, yloc].imshow(img)
        # 设置标题并取消横纵坐标上的刻度
        axes[xloc, yloc].set_title(label)
        axes[xloc, yloc].set(xticks=[], yticks=[])
batch_size = 18
# data.DataLoader( )函数的作用在于根据传入的数据集和批量大小来返回小批量数据集
X, y = next(iter(data.DataLoader(mnist_train, batch_size=batch_size)))
show_image(X.reshape(batch_size, 28, 28), num_row=2, num_col=9,
          titles=get_fashion_mnist_labels(y))

图像分类数据集_第1张图片

batch_size = 256

def get_dataloader_workers():
    """使用4个进程来读取的数据。"""
    return 4

train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())

timer = d2l.Timer()
for X,y in train_iter:
    continue
f'{timer.stop():.2f} sec'

你可能感兴趣的:(动手深度学习,java,计算机视觉,python)