Pytorch提供多种数据集,要下载的话只需进入Pytorch官网,点击Docs下的torchvision,进入之后可以看到下方有多种常用数据集,如COCO、MNIST等,点击想要下载的数据集,进入后会有语句告知如何下载。
例如要下载CIFAR数据集,则通过
torchvision.datasets.CIFAR10(root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)
下载。第一个参数root为要下载到本地的地址;第二个参数如果为train=True则下载作为train数据集,如果为False则test数据集;中间参数可以不填;最后一个参数如果为download=True,则代表下载到本地。代码如下:
import torchvision
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)
print(test_set[0])
注意第三个参数为是否要对数据集进行transform类转换,即通过transform的类进行操作,例如要将下载下来的图片转化为ToTensor格式,代码如下:
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)
dataloader相当于是对dataset处理的工具,其作用是对数据集进行预处理,最主要是确定每次进网络前多少张图片打包。进入pytorch官网,点击Docs下的Pytorch,然后点击左侧搜索,输入dataloader,点击第一个链接后,可以看到使用dataloader的方式,语句为
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
参数介绍:
dataset:之前下载的dataset或自制的数据集;
batch_size:每次加载的图片数,即打包尺寸;
shuffle:每轮测试是否乱序,为True则每轮所取数据顺序不同,为False则相同,默认为False,一般我们取True;
num_work:加载数据进程数,多进程可提升速度,但可能会有其它问题,默认为0;
drop_last:当每次取的batch_size取到最后一次,数据集中的图片数不足batch_size时是否舍去,为True则舍去,为False则保留。
代码如下:
import torchvision
# 准备的测试数据集
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
# 测试数据集中第一张图片及target
img, target = test_data[0]
print(img.shape)
print(target)
writer = SummaryWriter("dataloader")
for epoch in range(2):
step = 0
for data in test_loader:
imgs, targets = data
# print(imgs.shape)
# print(targets)
writer.add_images("Epoch: {}".format(epoch), imgs, step)
step = step + 1
writer.close()