Pytorch下dataset和dataloader的下载和使用

dataset和dataloader的使用

  • 1.dataset的使用
  • 2.dataloader的使用

1.dataset的使用

Pytorch提供多种数据集,要下载的话只需进入Pytorch官网,点击Docs下的torchvision,进入之后可以看到下方有多种常用数据集,如COCO、MNIST等,点击想要下载的数据集,进入后会有语句告知如何下载。
例如要下载CIFAR数据集,则通过

torchvision.datasets.CIFAR10(root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)

下载。第一个参数root为要下载到本地的地址;第二个参数如果为train=True则下载作为train数据集,如果为False则test数据集;中间参数可以不填;最后一个参数如果为download=True,则代表下载到本地。代码如下:

import torchvision

train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)

print(test_set[0])

注意第三个参数为是否要对数据集进行transform类转换,即通过transform的类进行操作,例如要将下载下来的图片转化为ToTensor格式,代码如下:

import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)

2.dataloader的使用

dataloader相当于是对dataset处理的工具,其作用是对数据集进行预处理,最主要是确定每次进网络前多少张图片打包。进入pytorch官网,点击Docs下的Pytorch,然后点击左侧搜索,输入dataloader,点击第一个链接后,可以看到使用dataloader的方式,语句为

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)

参数介绍:

dataset:之前下载的dataset或自制的数据集;
batch_size:每次加载的图片数,即打包尺寸;
shuffle:每轮测试是否乱序,为True则每轮所取数据顺序不同,为False则相同,默认为False,一般我们取True;
num_work:加载数据进程数,多进程可提升速度,但可能会有其它问题,默认为0;
drop_last:当每次取的batch_size取到最后一次,数据集中的图片数不足batch_size时是否舍去,为True则舍去,为False则保留。

代码如下:

import torchvision

# 准备的测试数据集
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())

test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

# 测试数据集中第一张图片及target
img, target = test_data[0]
print(img.shape)
print(target)

writer = SummaryWriter("dataloader")
for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        # print(imgs.shape)
        # print(targets)
        writer.add_images("Epoch: {}".format(epoch), imgs, step)
        step = step + 1

writer.close()

你可能感兴趣的:(pytorch,pytorch,深度学习,python)