Pytorch的数据导入依靠 torch.utils.data.DataLoader 和torch.utils.data.Dataset(或torch.utils.data.IterableDataset)两个类来实现。
在 torch.utils.data 官方文档中提到,torch.utils.data.DataLoader 是pytorch 数据导入的核心工具,返回一个可迭代对象用于提取数据集中的数据。
DataLoader的参数如下:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
dataset 参数表示要加载数据的数据集对象。PyTorch支持两种不同类型的数据集:map-style datasets 和 iterable-style datasets。
Map-style 数据集应用__getitem()__ 和__len__()两种协议,是一种索引/键值与数据集样本的映射关系。也就是说 dataset 并不是读取了数据,而是读取了数据的索引/键值,后续通过这个索引/键值来访问数据。当我们访问 dataset[idx] 时,可以从磁盘读取第 idx 张图片和与之对应的标签,这一过程是对磁盘上的数据样本随机访问的情况。
其中__getitem()__是通过给定的索引/键值返回相应的数据集样本,len()是返回数据集的大小。
Iterable-style 数据集应用__iter()__协议,是一个可迭代的数据集。Iterable_style 数据集是读取了数据。这种类型的数据集特别适合于随机读取昂贵甚至不可能的情况,以及批量大小取决于所取数据的情况。
Note:对于多进程加载数据会出现重复读取相同数据情况。
e.g.:
# should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
# Single-process loading
print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
# [3, 4, 5, 6]
# Directly doing multi-process loading yields duplicate data
print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
# [3, 3, 4, 4, 5, 5, 6, 6]
import os
import pickle
import numpy as np
def load_CIFAR_batch(filename):
'''load single batch of cifar'''
with open(filename, 'rb') as f:
datadict = pickle.load(f, encoding='latin1')
x = datadict['data']
print(x.shape)
y = datadict['labels']
X = x.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
Y=np.array(y)
return X, Y
def load_CIFAR10(path):
'''load all of cifar'''
xs = []
ys = []
for b in range(1,6):
f = os.path.join(path,'data_batch_%d' % (b, ))
X, Y = load_CIFAR_batch(f)
xs.append(X)
ys.append(Y)
Xtr = np.concatenate(xs)
Ytr = np.concatenate(ys)
del X, Y
Xte, Yte = load_CIFAR_batch(os.path.join(path, 'test_batch'))
return Xtr, Ytr, Xte, Yte
Xtr, Ytr, Xte, Yte = load_CIFAR10('D:\\数据集\\cifar-10\\')
import torch
from torch.utils.data import Dataset, IterableDataset, DataLoader
# Xtr, Xte展开
Xtr_rows = Xtr.reshape(Xtr.shape[0], 32 * 32 * 3)
Xte_rows = Xte.reshape(Xte.shape[0], 32 * 32 * 3)
# 创建子类
class MyMapstyleDataset(Dataset):
def __init__(self, datas, labels):
# super(MyMapstyleDataset).__init__()
self.datas = datas
self.labels = labels
def __getitem__(self, index):
data = torch.tensor(self.datas[index])
label = torch.tensor(self.labels[index])
return data, label
def __len__(self):
return len(self.datas)
Dataset = MyMapstyleDataset(Xtr_rows, Ytr)
train_dataloader = DataLoader(IterableDataset, shuffle=False, batch_size=8, num_workers=0)
# 创建子类
class MyIterstyleDataset(IterableDataset):
def __init__(self, start, end):
super(MyIterstyleDataset).__init__()
assert end > start, "Error"
self.start = start
self.end = end
# self.filepath = filepath
def _sample_gernerator(self, start, end):
# 当数据量大,无法一次Load进内存,可以通过这个函数以数据流的形式加载
for i in range(end-start):
sample = {"data":torch.tensor(Xtr[start+i, :]), "label":torch.tensor(Ytr[start+i, :])}
yield sample
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:# Single Worker
iter_start = self.start
iter_end = self.end
else:# Multiple Workers
per_work = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
sample_iterator = self._sample_gernerator(iter_start, iter_end)
return sample_iterator
def __len__(self):
return self.end - self.start
IterableDataset = MyIterstyleDataset(0, len(Xtr_rows))
Iter_train_dataloader = DataLoader(IterableDataset, shuffle=False, batch_size=8, num_workers=0)
Sampler 提供多种数据读取方式。对于 Iterable-Style 数据集,数据读取是自定义的。因此 Sampler 更多的应用在 MapStyle 数据集中。torch.utils.data.Sampler 类生成特定的索引/键值序列用于数据读取,是数据集索引/键值的 Iterable object。
class Sampler(Gerneric[T_co]):
def __init__(self, data_source) -> None:
pass
def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError
class SequentialSampler(Sampler[int])
'''顺序采样,顺序始终相同.'''
class RandomSampler(Sampler[int])
'''随机采样,"replacement"为True时,可能修改dataset大小,具体看"num_samples"'''
class SubsetRandomSampler(Sampler[int]):
'''按照给定的索引/键值序列采样'''
class WeightedRandomSampler(Sampler[int]):
'''对[0,...,len(weights)-1]的样本,按照weights的概率,随机抽取num_samples个样本'''
简单的 Sequential Sampler 和 shuffled Sampler 使用 DataLoader 的 shuffle 参数即可:True 为 shuffled;False 为 sequential。
同样用户可以自定义 sampler,提供一种__iter()__方法,每次迭代生成下一个索引/键值。
class BatchSampler(Sampler[list[int]]):
def __init__(self, sampler, batch_size, drop_last):
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last
def __iter__(self) -> Iterator[List[int]]:
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
def __len__(self) -> int:
if self.drop_last:
return len(self.sampler)
else:
return (len(self.sampler) + self.batch_size - 1)
DataLoader 支持将单个获取的数据自动整理为 batch ,其中 batchsize 和 drop_last 参数用于从 sampler 构成 batch_sampler 。
Note: 使用多进程从 Iterable-Style 数据集中读取数据时, drop_last 参数会丢弃每个 worker 的数据集副本的最后一个不完整 batch。
'''读取Map-Style数据集(Automatic batching Enable)'''
for incices in batch_sampler:
yield collate_fn([dataset[i] for i in indices])
'''读取Iterable-Style数据集(Automatic batching Enable)'''
dataset_iter = iter(dataset)
for indices in batch_sampler:
yield collate_fn([next(dataset_iter) for _ in indices])
当 batch_size(默认为1,1不是None)和 batch_sampler(默认)都是None时,Automatic batching禁用。
'''读取Map-Style数据集(Automatic batching Disnable)'''
for index in sampler:
yield collate_fn(dataset[index])
'''读取Iterable-Style数据集(Automatic batching Disnable)'''
for data in iter(dataset):
yield collate_fn(data)
当sampler或batch_sampler获取数据样本后,使用collate_fn 函数将样本列表整理成batch。
当 Automatic batching 禁用时: collate_fn 被每个单独的数据样本调用。在这种情况下,默认的 collate_fn 只是将NumPy 数组转换为PyTorch张量。
当 Automatic batching 禁用时: collate_fn 被每个数据样本列表调用。将读取的样本列表整理成batch。
e.g. Dataset 中每个样本是(image, class_index)的元组,collate_fn 返回image tensor和 label tensor。
关于默认的 collate_fn :
用户可以自定义 collate_fn 实现自定义 batching,比如,填充各种长的序列。
[1]. torch.utils.data.DataLoader官方文档
[2]. Pytorch IterableDataset的使用
[3]. 加载CIFAR_10