python中如何导入torch_[PyTorch 笔记] PyTorch中的数据导入

[Deep in PyTorch] PyTorch中的数据导入

简介

Pytorch导入数据主要依靠torch.utils.data.DataLoader和torch.utils.data.Dataset这两个类来完成。其中

torch.utils.data.Dataset:对数据集进行抽象,变为的一个类。其结果是一个可迭代对象,可用于迭代提取数据集中的数据。

torch.utils.data.DataLoader:接受torch.utils.data.Dataset作为输入,得到DataLoader,其将Dataset与不同的取样器(Loader)联合,如SequentialSampler和RandomSampler,也可以控制使用单进程还是多进程的迭代器。实质是提供了更为高级的可定制提取数据的数据读取器。

因此,我们需要了解,如果生成torch.utils.data.Dataset和torch.utils.data.DataLoader

Dataset

torch.utils.data.Dataset可通过两种方式生成,

一种是根据 pytorch built-in也就是内置的数据集生成Dataset,这种方式在运行时候会自动从pytorch的网站上下载指定的数据集;

另一种是根据自己提供的数据生成Dataset,这种方式需要我们提供数据,且重写Dataset类中的一些方法。

Built-in dataset

使用Built-in 的dataset需要导入torchvision.dataset。其包含的数据集可以在以下url中找到: https://pytorch.org/docs/stable/torchvision/datasets.html

有一些数据集需要提前先下载好的,在上面的url里面有说明。现在我们来以cifar10为例子生成Dataset。

example:

import torch

import torchvision

cf10_data = torchvision.datasets.CIFAR10('dataset/cifar/', download=True)

Custom dataset

而其实,我们更关注的是,按照我们的数据,自定义Dataset类。 任何自定义的数据集都要继承自torch.utils.data.Dataset,然后要重写两个函数:__len__(self)和__getitem__(self, idx),因此自定义数据集的格式大概为:

from torch.utils.data import Dataset

class MyDataset(Dataset):

def __init__(self):

pass

def __len__(self):

pass

def __getitem(self, idx):

pass

其中: __len__:负责返回最大元素个数。 __getitem__返回第idx个元素。

实例:

class PigsDataset(Dataset):

def __init__(self, root_dir, size=(224,224)):

self.files = glob(root_dir)

self.size = size

def __len__(self):

return len(self.files)

def __getitem__(self, idx):

img = np.asarray(Image.open(self.files[idx]).resize(self.size))

label = self.files[idx].split('/')[-2]

return img, label

DataLoader

实现了torch.utils.data.Dataset后,我们需要构建一个DataLoader

DataLoader的参数:

DataLoader(dataset,

batch_size=1,

shuffle=False,

sampler=None,

batch_sampler=None,

num_workers=0,

collate_fn=None,

pin_memory=False,

drop_last=False,

timeout=0,

worker_init_fn=None,

*,

prefetch_factor=2,

persistent_workers=False)

其中比较常用的为:dataset:就是torch.utils.data.Dataset对象

batch_size:批的大小,当size=1时,相当于每次只处理一个,size越大,每批处理的数据就越多。

shuffle:决定是是使用sequential还是shuffled的 sampler

sampler / batch_sampler:自定义sampler。sampler 每次返回的指定值,作为下次提取数据时的index / key

num_workers:在导入数据时,是单线程还是多线程,默认是单线程。(注意,不建议在使用多线程的情况下返回CUDA的tensor)

collate_fn:用于自定义sample 如何形成 batch sample 的函数。

使用collate_fn

在两种情况下collate_fn的效果是不同的:Automatic batching 不可用

collate_fn会被每个单独的data样本调用,通常的操作,只是将数据中的numpy的数组转为torch的tensor类型。Automatic batching 可用

这时候的collate_fn的输入是 list对象,是一批的data样本。 list的长度是一个batch_size,而list中每个元素都是__getitem__得到的结果

def my_collate(batch):

# The input is a list of tuple (img, label)

# [(img0, label0), (img1, label1), (img2, label2)...]

batch.sort(key=lambda x: len(x[1]), reverse=True)

img, label = zip(*batch)

pad_label = []

lens = []

max_len = len(label[0])

for i in range(len(label)):

temp_label = [0] * max_len

temp_label[:len(label[i])] = label[i]

pad_label.append(temp_label)

lens.append(len(label[i]))

return img, pad_label, lens

因此,我们可以举一个DataLoader的实例:

from torch.utils.data import DataLoader

train_loader = DataLoader(training_set, batch_size=args.batch_size, shuffle=True, collate_fn=collate)

你可能感兴趣的:(python中如何导入torch_[PyTorch 笔记] PyTorch中的数据导入)