PyTorch Basics (5): torchvision Datasets and DataLoader

Introduction

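While watching the PyTorch version of Mu Li's "Dive into Deep Learning" (动手学深度学习) course on Bilibili, I found many PyTorch operations I did not understand. Often I was only doing the "hands-on" part without really thinking about it. So I plan to use the winter break to catch up on and organize the PyTorch operations, so that I can keep up with the course.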

Learning resources:

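  • Bilibili uploader 我是土堆 (Tudui): "PyTorch Deep Learning Quick Start Tutorial (absolutely easy to understand!)" [小土堆]
  • PyTorch Chinese handbook (pytorch handbook)
  • Datawhale open-source tutorial: Thorough PyTorch (thorough-pytorch)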

1 Datasets in torchvision

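Here we use one of the datasets bundled with torchvision, CIFAR10: an image-classification dataset of 32×32-pixel color images in 10 classes (50,000 training images and 10,000 test images).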

1.1 Reading images as-is (PIL)

import torchvision

train_set = torchvision.datasets.CIFAR10(root = './dataset', train = True, download = True)
test_set = torchvision.datasets.CIFAR10(root = './dataset', train = False, download = True)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./dataset\cifar-10-python.tar.gz
Files already downloaded and verified
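# dir(train_set)  # uncomment to list the dataset's attributes
print(train_set.class_to_idx) # mapping from class name to numeric label
train_set[0] # each item is a tuple: a PIL image and an integer target (label)

img, target = test_set[0] # img is the PIL image, target is its label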
{'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
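The printed class_to_idx maps class names to labels. To go the other way (label to name), you can simply invert the dict. A minimal sketch, using the mapping printed above; for CIFAR10, torchvision also keeps the same names, in label order, in `train_set.classes`, so `train_set.classes[target]` gives the same result.

```python
# class_to_idx as printed above for CIFAR10
class_to_idx = {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4,
                'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}

# invert the mapping: numeric label -> class name
idx_to_class = {idx: name for name, idx in class_to_idx.items()}

target = 3  # e.g. the label returned by test_set[0]
print(idx_to_class[target])  # cat
```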

1.2 Applying a transform while reading images

  • Since the dataset images are already small, this time we only convert them to tensors.
# set the transform argument when loading the dataset
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

train_set = torchvision.datasets.CIFAR10(root = './dataset', train = True, download = True, transform = dataset_transform)
test_set = torchvision.datasets.CIFAR10(root = './dataset', train = False, download = True, transform = dataset_transform)

test_set[0]
Files already downloaded and verified
Files already downloaded and verified

(tensor([[[0.6196, 0.6235, 0.6471,  ..., 0.5373, 0.4941, 0.4549],
          [0.5961, 0.5922, 0.6235,  ..., 0.5333, 0.4902, 0.4667],
          [0.5922, 0.5922, 0.6196,  ..., 0.5451, 0.5098, 0.4706],
          ...,
          [0.2667, 0.1647, 0.1216,  ..., 0.1490, 0.0510, 0.1569],
          [0.2392, 0.1922, 0.1373,  ..., 0.1020, 0.1137, 0.0784],
          [0.2118, 0.2196, 0.1765,  ..., 0.0941, 0.1333, 0.0824]],
 
         [[0.4392, 0.4353, 0.4549,  ..., 0.3725, 0.3569, 0.3333],
          [0.4392, 0.4314, 0.4471,  ..., 0.3725, 0.3569, 0.3451],
          [0.4314, 0.4275, 0.4353,  ..., 0.3843, 0.3725, 0.3490],
          ...,
          [0.4863, 0.3922, 0.3451,  ..., 0.3804, 0.2510, 0.3333],
          [0.4549, 0.4000, 0.3333,  ..., 0.3216, 0.3216, 0.2510],
          [0.4196, 0.4118, 0.3490,  ..., 0.3020, 0.3294, 0.2627]],
 
         [[0.1922, 0.1843, 0.2000,  ..., 0.1412, 0.1412, 0.1294],
          [0.2000, 0.1569, 0.1765,  ..., 0.1216, 0.1255, 0.1333],
          [0.1843, 0.1294, 0.1412,  ..., 0.1333, 0.1333, 0.1294],
          ...,
          [0.6941, 0.5804, 0.5373,  ..., 0.5725, 0.4235, 0.4980],
          [0.6588, 0.5804, 0.5176,  ..., 0.5098, 0.4941, 0.4196],
          [0.6275, 0.5843, 0.5176,  ..., 0.4863, 0.5059, 0.4314]]]),
 3)

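The output is now a 3-D tensor of shape [3, 32, 32] (channels × height × width, with pixel values scaled to [0, 1]) together with an integer target. Next, write the first 10 test images to TensorBoard: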

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('(5)') # event files are written to the folder "(5)"

for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i) # img is a CHW tensor, i is the step
    
writer.close()

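Then, in a terminal where the PyTorch environment is active, run tensorboard --logdir="(5)" and open the URL it prints (by default http://localhost:6006/).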
[Figure 1: the first 10 test_set images in TensorBoard]

2 DataLoader (from torch.utils.data)

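See torch.utils.data.DataLoader in the official PyTorch documentation. The parameters used below are batch_size (number of samples per batch), shuffle (reshuffle the data at every epoch), num_workers (number of loading subprocesses; 0 means loading happens in the main process) and drop_last (drop the final batch if it is incomplete).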

import torchvision
from torch.utils.data import DataLoader

# prepare the test dataset
test_data = torchvision.datasets.CIFAR10("./dataset", train = False, transform = torchvision.transforms.ToTensor())

test_loader = DataLoader(dataset = test_data, batch_size = 4, shuffle = True, num_workers = 0, drop_last = False)

# first image of the test dataset and its target
img, target = test_data[0]
print(img.shape)
print(target)
torch.Size([3, 32, 32])
3

With batch_size = 4, the DataLoader collects 4 images and 4 targets and stacks each group, so every iteration yields a 4-D tensor of shape [4, 3, 32, 32] and a 1-D tensor of 4 labels.

for data in test_loader:
    imgs, targets = data
    print(imgs.shape)
    print(targets)
    break
torch.Size([4, 3, 32, 32])
tensor([7, 9, 6, 5])
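This batching works for any map-style Dataset, that is, any object with __len__ and __getitem__. The sketch below uses a hypothetical FakeImages dataset that returns random 3×32×32 tensors with integer labels (mimicking CIFAR10 after ToTensor) to show that the DataLoader's default collation stacks the images into [batch_size, 3, 32, 32] and the labels into a 1-D tensor.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class FakeImages(Dataset):
    """Hypothetical stand-in for CIFAR10 + ToTensor: random images, labels 0-9."""
    def __init__(self, n=10):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        img = torch.rand(3, 32, 32)   # C x H x W, values in [0, 1)
        target = idx % 10             # plain Python int, like CIFAR10 targets
        return img, target

loader = DataLoader(FakeImages(), batch_size=4, shuffle=True)
imgs, targets = next(iter(loader))
print(imgs.shape)      # torch.Size([4, 3, 32, 32])
print(targets.shape)   # torch.Size([4])
```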

Next, change batch_size to 64 and write the batches to TensorBoard.

2.1 drop_last

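  • If the DataLoader is created with drop_last = True, the incomplete final batch (shown in the second screenshot below) is not loaded, so TensorBoard shows one step fewer.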
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("dataloader")
step = 0

test_loader = DataLoader(dataset = test_data, batch_size = 64, shuffle = True, num_workers = 0, drop_last = False)
for data in test_loader:
    imgs, targets = data
    writer.add_images("test_data", imgs, step)
    step += 1
    
writer.close()

[Figure 2: test_data batches (batch_size = 64) in TensorBoard]

[Figure 3: the last, incomplete batch in TensorBoard]
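The CIFAR10 test set has 10,000 images; with batch_size = 64 that is 156 full batches (156 × 64 = 9,984) plus a remainder of 16 images. So drop_last = False gives 157 steps, the last one holding only 16 images, while drop_last = True gives 156. A sketch verifying the arithmetic with a dummy TensorDataset of the same length (used instead of the real dataset so it runs without a download):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# dummy dataset with the same length as the CIFAR10 test set
dummy = TensorDataset(torch.arange(10000))

keep = DataLoader(dummy, batch_size=64, drop_last=False)
drop = DataLoader(dummy, batch_size=64, drop_last=True)

print(len(keep), len(drop))   # 157 156
last_batch = list(keep)[-1][0]
print(len(last_batch))        # 16: the incomplete batch that drop_last=True skips
```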

2.2 shuffle=True

With shuffle = True the data is reshuffled at every epoch, so each epoch yields different batches.

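# create the loader once; with shuffle = True it reshuffles every time it is iterated
test_loader = DataLoader(dataset = test_data, batch_size = 64, shuffle = True, num_workers = 0, drop_last = True)
writer = SummaryWriter("dataloader")

for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        writer.add_images("Epoch:{}".format(epoch), imgs, step)
        step += 1

writer.close()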

[Figure 4: TensorBoard view of the batches from the two epochs]

[Figure 5: TensorBoard view of the batches from the two epochs (continued)]
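With shuffle = True the DataLoader draws a fresh random permutation each time it is iterated, so each epoch sees the samples in a different order, while every sample still appears exactly once; with shuffle = False the order is always 0, 1, 2, and so on. A sketch on a 20-element dummy dataset; the torch.Generator seed is an assumption added only to make the run reproducible.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

data = TensorDataset(torch.arange(20))   # dummy dataset: items 0..19
g = torch.Generator().manual_seed(0)     # fixed seed, only for reproducibility

shuffled = DataLoader(data, batch_size=5, shuffle=True, generator=g)
ordered = DataLoader(data, batch_size=5, shuffle=False)

def epoch_order(loader):
    # concatenate all batches of one pass (one epoch) into a flat list
    return torch.cat([batch[0] for batch in loader]).tolist()

epoch0 = epoch_order(shuffled)
epoch1 = epoch_order(shuffled)
print(epoch0 != epoch1)                          # a new permutation each epoch
print(sorted(epoch0) == list(range(20)))         # every item exactly once
print(epoch_order(ordered) == list(range(20)))   # fixed order without shuffle
```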
