Pytorch的数据读取主要包含三个类:
这三者大致是一个依次封装的关系: 1.被装进2., 2.被装进3.
torch.utils.data.Dataset
是一个抽象类, 自定义的Dataset需要继承它并且实现两个成员方法:
__getitem__()
__len__()
第一个最为重要, 即每次怎么读数据. 以图片为例:
def __getitem__(self, index):
img_path, label = self.data[index].img_path, self.data[index].label
img = Image.open(img_path)
return img, label
值得一提的是, pytorch还提供了很多常用的transform, 在torchvision.transforms
里面, 本文中不多介绍, 常用的有Resize
, RandomCrop
, Normalize
, ToTensor
(这个极为重要, 可以把一个PIL或numpy图片转为torch.Tensor
, 但是好像对numpy数组的转换比较受限, 所以这里建议在__getitem__()
里面用PIL来读图片, 而不是用skimage.io).
第二个比较简单, 就是返回整个数据集的长度:
def __len__(self):
return len(self.data)
torch.utils.data.DataLoader
类定义为:
class torch.utils.data.DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=,
pin_memory=False,
drop_last=False
)
可以看到, 主要参数有这么几个:
dataset
: 即上面自定义的dataset.collate_fn
: 这个函数用来打包batchnum_worker
: 非常简单的多线程方法, 只要设置为>=1, 就可以多线程预读数据这个类其实就是下面将要讲的DataLoaderIter
的一个框架, 一共干了两件事:
DataLoaderIter
,__iter__()
函数, 把自己 "装进" DataLoaderIter
里面.def __iter__(self):
return DataLoaderIter(self)
torch.utils.data.dataloader.DataLoaderIter
上面提到, DataLoader
就是DataLoaderIter
的一个框架, 用来传给DataLoaderIter
一堆参数, 并把自己装进DataLoaderIter
里。其实到这里就可以满足大多数训练的需求了, 比如
class CustomDataset(Dataset):
# 自定义自己的dataset
dataset = CustomDataset()
dataloader = Dataloader(dataset, ...)
for data in dataloader:
# training...
在for 循环里, 总共有三点操作:
dataloader
的__iter__()
方法, 产生了一个DataLoaderIter
DataLoaderIter
的__next__()
来得到batch, 具体操作就是, 多次调用dataset的__getitem__()
方法 (如果num_worker
>0就多线程调用), 然后用collate_fn
来把它们打包成batch. 中间还会涉及到shuffle
, 以及sample
的方法等. __next__()
抛出一个StopIteration
异常, for
循环结束, dataloader
失效.其实上面三个类已经可以搞定了, 仅供参考
class DataProvider:
def __init__(self, batch_size, is_cuda):
self.batch_size = batch_size
self.dataset = Dataset_triple(self.batch_size,
transform_=transforms.Compose(
[transforms.Scale([224, 224]),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])]),
)
self.is_cuda = is_cuda # 是否将batch放到gpu上
self.dataiter = None
self.iteration = 0 # 当前epoch的batch数
self.epoch = 0 # 统计训练了多少个epoch
def build(self):
dataloader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True, num_workers=0, drop_last=True)
self.dataiter = DataLoaderIter(dataloader)
def next(self):
if self.dataiter is None:
self.build()
try:
batch = self.dataiter.next()
self.iteration += 1
if self.is_cuda:
batch = [batch[0].cuda(), batch[1].cuda(), batch[2].cuda()]
return batch
except StopIteration: # 一个epoch结束后reload
self.epoch += 1
self.build()
self.iteration = 1 # reset and return the 1st batch
batch = self.dataiter.next()
if self.is_cuda:
batch = [batch[0].cuda(), batch[1].cuda(), batch[2].cuda()]
return batch
感谢以下链接提供的参考:
https://zhuanlan.zhihu.com/p/30934236
https://blog.csdn.net/u014380165/article/details/79167753