[Deep in PyTorch] PyTorch中的数据导入
简介
Pytorch导入数据主要依靠torch.utils.data.DataLoader和torch.utils.data.Dataset这两个类来完成。其中
torch.utils.data.Dataset:对数据集进行抽象,变为的一个类。其结果是一个可迭代对象,可用于迭代提取数据集中的数据。
torch.utils.data.DataLoader:接受torch.utils.data.Dataset作为输入,得到DataLoader,其将Dataset与不同的取样器(Loader)联合,如SequentialSampler和RandomSampler,也可以控制使用单进程还是多进程的迭代器。实质是提供了更为高级的可定制提取数据的数据读取器。
因此,我们需要了解,如果生成torch.utils.data.Dataset和torch.utils.data.DataLoader
Dataset
torch.utils.data.Dataset可通过两种方式生成,
一种是根据 pytorch built-in也就是内置的数据集生成Dataset,这种方式在运行时候会自动从pytorch的网站上下载指定的数据集;
另一种是根据自己提供的数据生成Dataset,这种方式需要我们提供数据,且重写Dataset类中的一些方法。
Built-in dataset
使用Built-in 的dataset需要导入torchvision.dataset。其包含的数据集可以在以下url中找到: https://pytorch.org/docs/stable/torchvision/datasets.html
有一些数据集需要提前先下载好的,在上面的url里面有说明。现在我们来以cifar10为例子生成Dataset。
example:
import torch
import torchvision
cf10_data = torchvision.datasets.CIFAR10('dataset/cifar/', download=True)
Custom dataset
而其实,我们更关注的是,按照我们的数据,自定义Dataset类。 任何自定义的数据集都要继承自torch.utils.data.Dataset,然后要重写两个函数:__len__(self)和__getitem__(self, idx),因此自定义数据集的格式大概为:
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self):
pass
def __len__(self):
pass
def __getitem(self, idx):
pass
其中: __len__:负责返回最大元素个数。 __getitem__返回第idx个元素。
实例:
class PigsDataset(Dataset):
def __init__(self, root_dir, size=(224,224)):
self.files = glob(root_dir)
self.size = size
def __len__(self):
return len(self.files)
def __getitem__(self, idx):
img = np.asarray(Image.open(self.files[idx]).resize(self.size))
label = self.files[idx].split('/')[-2]
return img, label
DataLoader
实现了torch.utils.data.Dataset后,我们需要构建一个DataLoader
DataLoader的参数:
DataLoader(dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
*,
prefetch_factor=2,
persistent_workers=False)
其中比较常用的为:dataset:就是torch.utils.data.Dataset对象
batch_size:批的大小,当size=1时,相当于每次只处理一个,size越大,每批处理的数据就越多。
shuffle:决定是是使用sequential还是shuffled的 sampler
sampler / batch_sampler:自定义sampler。sampler 每次返回的指定值,作为下次提取数据时的index / key
num_workers:在导入数据时,是单线程还是多线程,默认是单线程。(注意,不建议在使用多线程的情况下返回CUDA的tensor)
collate_fn:用于自定义sample 如何形成 batch sample 的函数。
使用collate_fn
在两种情况下collate_fn的效果是不同的:Automatic batching 不可用
collate_fn会被每个单独的data样本调用,通常的操作,只是将数据中的numpy的数组转为torch的tensor类型。Automatic batching 可用
这时候的collate_fn的输入是 list对象,是一批的data样本。 list的长度是一个batch_size,而list中每个元素都是__getitem__得到的结果
def my_collate(batch):
# The input is a list of tuple (img, label)
# [(img0, label0), (img1, label1), (img2, label2)...]
batch.sort(key=lambda x: len(x[1]), reverse=True)
img, label = zip(*batch)
pad_label = []
lens = []
max_len = len(label[0])
for i in range(len(label)):
temp_label = [0] * max_len
temp_label[:len(label[i])] = label[i]
pad_label.append(temp_label)
lens.append(len(label[i]))
return img, pad_label, lens
因此,我们可以举一个DataLoader的实例:
from torch.utils.data import DataLoader
train_loader = DataLoader(training_set, batch_size=args.batch_size, shuffle=True, collate_fn=collate)