我们之所以要自己写DataSet,是因为原本提供的DataSet无法满足我们的实际需求,此时就需要我们自定义。通过继承 torch.utils.data.Dataset来实现,
在继承的时候,需要 override
三个方法:
__init__
: 用来初始化一些有关操作数据集的参数__getitem__:定义数据获取的方式(包括读取数据,对数据进行变换等),
该方法支持从 0 到 len(self)-1的索引。obj[index]
等价于obj.__getitem__
__len__:获取数据集的大小。len(obj)
等价于obj.__len__()
dataset中应尽量只包含只读对象,避免修改任何可变对象。如下面例子中的self.num
可能在多进程下出问题:
class BadDataset(Dataset):
def __init__(self):
self.datas = range(100)
self.num = 0 # read data times
def __getitem__(self, index):
self.num += 1
return self.datas[index]
自定义DataSet的框架:
class CustomDataset(data.Dataset):#需要继承data.Dataset
def __init__(self):
# TODO
# 1. Initialize file path or list of file names.
pass
def __getitem__(self, index):
# TODO
# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
# 2. Preprocess the data (e.g. torchvision.Transform).
# 3. Return a data pair (e.g. image and label).
#这里需要注意的是,第一步:read one data,是一个data
pass
def __len__(self):
# You should change 0 to the total size of your dataset.
return 0
DataSet创建及使用的完整流程如下:
代码示例如下:
dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
for data in dataloader:
....
# 以下两个代码是等价的
for data in dataloader:
...
# 等价与
iterr = iter(dataloader)
while True:
try:
next(iterr)
except StopIteration:
break
在DataLoader
中,iter(dataloader)
返回的是一个 DataLoaderIter
对象, 这个才是我们一直 next
的 对象。
这个DataLoaderIter
其实就是DataLoader
类的__iter__()
方法的返回值:
class DataLoader(object):
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None,
num_workers=0, collate_fn=default_collate, pin_memory=False,
drop_last=False):
self.dataset = dataset
self.batch_size = batch_size
self.num_workers = num_workers
self.collate_fn = collate_fn
self.pin_memory = pin_memory
self.drop_last = drop_last
if batch_sampler is not None:
if batch_size > 1 or shuffle or sampler is not None or drop_last:
raise ValueError('batch_sampler is mutually exclusive with '
'batch_size, shuffle, sampler, and drop_last')
if sampler is not None and shuffle:
raise ValueError('sampler is mutually exclusive with shuffle')
if batch_sampler is None:
if sampler is None:
if shuffle:
# dataset.__len__() 在 Sampler 中被使用。
# 目的是生成一个 长度为 len(dataset) 的 序列索引(随机的)。
sampler = RandomSampler(dataset)
else:
# dataset.__len__() 在 Sampler 中被使用。
# 目的是生成一个 长度为 len(dataset) 的 序列索引(顺序的)。
sampler = SequentialSampler(dataset)
# Sampler 是个迭代器,一次之只返回一个 索引
# BatchSampler 也是个迭代器,但是一次返回 batch_size 个 索引
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.sampler = sampler
self.batch_sampler = batch_sampler
def __iter__(self):
return DataLoaderIter(self)
def __len__(self):
return len(self.batch_sampler)
DataLoader()的各个参数含义如下:
1. dataset:加载的数据集,这个从DataSet()函数而来。
2. batch_size:batch size,设定每次训练迭代时加载的数据量。
3. shuffle::是否将数据打乱
4. sampler: 样本抽样
5. num_workers:使用多进程加载的进程数,0代表不使用多进程,设定多进程可以使得加载数据时更加快速。
6. collate_fn: 如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可
7. pin_memory:是否将数据(tensor)保存在pin memory区,pin memory中的数据转到GPU中会快一些
8. drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃,False表示不丢弃。
注意,这个DataLoaderIter
中*init(self, loader)*中的loader就是对应的DataLoader类的实例。
class DataLoaderIter(object):
"Iterates once over the DataLoader's dataset, as specified by the sampler"
def __init__(self, loader):
# loader 是 DataLoader 对象
self.dataset = loader.dataset
# 这个留在最后一个部分介绍
self.collate_fn = loader.collate_fn
self.batch_sampler = loader.batch_sampler
# 表示 开 几个进程。
self.num_workers = loader.num_workers
# 是否使用 pin_memory
self.pin_memory = loader.pin_memory
self.done_event = threading.Event()
# 这样就可以用 next 操作 batch_sampler 了
self.sample_iter = iter(self.batch_sampler)
if self.num_workers > 0:
# 用来放置 batch_idx 的队列,其中元素的是 一个 list,其中放了一个 batch 内样本的索引
self.index_queue = multiprocessing.SimpleQueue()
# 用来放置 batch_data 的队列,里面的 元素的 一个 batch的 数据
self.data_queue = multiprocessing.SimpleQueue()
# 当前已经准备好的 batch 的数量(可能有些正在 准备中)
# 当为 0 时, 说明, dataset 中已经没有剩余数据了。
# 初始值为 0, 在 self._put_indices() 中 +1,在 self.__next__ 中减一
self.batches_outstanding = 0
self.shutdown = False
# 用来记录 这次要放到 index_queue 中 batch 的 idx
self.send_idx = 0
# 用来记录 这次要从的 data_queue 中取出 的 batch 的 idx
self.rcvd_idx = 0
# 因为多线程,可能会导致 data_queue 中的 batch 乱序
# 用这个来保证 batch 的返回 是 idx 升序出去的。
self.reorder_dict = {}
# 这个地方就开始 开多进程了,一共开了 num_workers 个进程
# 执行 _worker_loop , 下面将介绍 _worker_loop
self.workers = [
multiprocessing.Process(
target=_worker_loop,
args=(self.dataset, self.index_queue, self.data_queue, self.collate_fn))
for _ in range(self.num_workers)]
for w in self.workers:
w.daemon = True # ensure that the worker exits on process exit
w.start()
if self.pin_memory:
in_data = self.data_queue
self.data_queue = queue.Queue()
self.pin_thread = threading.Thread(
target=_pin_memory_loop,
args=(in_data, self.data_queue, self.done_event))
self.pin_thread.daemon = True
self.pin_thread.start()
# prime the prefetch loop
# 初始化的时候,就将 2*num_workers 个 (batch_idx, sampler_indices) 放到 index_queue 中。
for _ in range(2 * self.num_workers):
self._put_indices()
def __len__(self):
return len(self.batch_sampler)
def __next__(self):
if self.num_workers == 0: # same-process loading
indices = next(self.sample_iter) # may raise StopIteration
batch = self.collate_fn([self.dataset[i] for i in indices])
if self.pin_memory:
batch = pin_memory_batch(batch)
return batch
# check if the next sample has already been generated
if self.rcvd_idx in self.reorder_dict:
batch = self.reorder_dict.pop(self.rcvd_idx)
return self._process_next_batch(batch)
if self.batches_outstanding == 0:
# 说明没有 剩余 可操作数据了, 可以停止 worker 了
self._shutdown_workers()
raise StopIteration
while True:
# 这里的操作就是 给 乱序的 data_queue 排一排 序
assert (not self.shutdown and self.batches_outstanding > 0)
idx, batch = self.data_queue.get()
# 一个 batch 被 返回,batches_outstanding -1
self.batches_outstanding -= 1
if idx != self.rcvd_idx:
# store out-of-order samples
self.reorder_dict[idx] = batch
continue
# 返回的时候,再向 indice_queue 中 放下一个 (batch_idx, sample_indices)
return self._process_next_batch(batch)
next = __next__ # Python 2 compatibility
def __iter__(self):
return self
def _put_indices(self):
assert self.batches_outstanding < 2 * self.num_workers
indices = next(self.sample_iter, None)
if indices is None:
return
self.index_queue.put((self.send_idx, indices))
self.batches_outstanding += 1
self.send_idx += 1
def _process_next_batch(self, batch):
self.rcvd_idx += 1
# 放下一个 (batch_idx, sample_indices)
self._put_indices()
if isinstance(batch, ExceptionWrapper):
raise batch.exc_type(batch.exc_msg)
return batch
def __getstate__(self):
# TODO: add limited pickling support for sharing an iterator
# across multiple threads for HOGWILD.
# Probably the best way to do this is by moving the sample pushing
# to a separate thread and then just sharing the data queue
# but signalling the end is tricky without a non-blocking API
raise NotImplementedError("DataLoaderIterator cannot be pickled")
def _shutdown_workers(self):
if not self.shutdown:
self.shutdown = True
self.done_event.set()
for _ in self.workers:
# shutdown 的时候, 会将一个 None 放到 index_queue 中
# 如果 _worker_loop 获得了这个 None, _worker_loop 将会跳出无限循环,将会结束运行
self.index_queue.put(None)
def __del__(self):
if self.num_workers > 0:
self._shutdown_workers()
Dataset
、Dataloader
和DataLoaderIter
是层层封装的关系,最终在内部使用DataLoaderIter
进行迭代。
本人使用自己的数据集实现的一个DataSet类,如下所图所示:
class TrainsetLoader(Dataset):
def __init__(self, trainset_dir_hr, trainset_dir_lr,upscale_factor, patch_size, n_iters):
super(TrainsetLoader).__init__()
self.trainset_dir_hr = trainset_dir_hr
self.trainset_dir_lr = trainset_dir_lr
self.upscale_factor = upscale_factor
self.patch_size = patch_size
self.n_iters = n_iters
self.video_list = os.listdir(trainset_dir_hr)
def __getitem__(self, idx):
idx_video = random.randint(0, self.video_list.__len__()-1)
idx_frame = random.randint(1, 98) #1-98之间的随机数,包含左右边界
hr_dir = self.trainset_dir_hr + '/' + self.video_list[idx_video]
lr_dir = self.trainset_dir_lr + '/' + self.video_list[idx_video]
# read HR & LR frames
LR0 = Image.open(lr_dir + '/' + str(idx_frame-1).rjust(8,'0') + '.png')
LR1 = Image.open(lr_dir + '/' + str(idx_frame).rjust(8,'0') + '.png')
LR2 = Image.open(lr_dir + '/' + str(idx_frame+1).rjust(8,'0') + '.png')
HR0 = Image.open(hr_dir + '/' + str(idx_frame-1).rjust(8,'0')+ '.png')
HR1 = Image.open(hr_dir + '/' + str(idx_frame).rjust(8,'0') + '.png')
HR2 = Image.open(hr_dir + '/' + str(idx_frame+1).rjust(8,'0') + '.png')
LR0 = np.array(LR0, dtype=np.float32) / 255.0
LR1 = np.array(LR1, dtype=np.float32) / 255.0
LR2 = np.array(LR2, dtype=np.float32) / 255.0
HR0 = np.array(HR0, dtype=np.float32) / 255.0
HR1 = np.array(HR1, dtype=np.float32) / 255.0
HR2 = np.array(HR2, dtype=np.float32) / 255.0
# extract Y channel for LR inputs
HR0 = rgb2y(HR0)
HR1 = rgb2y(HR1)
HR2 = rgb2y(HR2)
LR0 = rgb2y(LR0)
LR1 = rgb2y(LR1)
LR2 = rgb2y(LR2)
# crop patchs randomly
HR0, HR1, HR2, LR0, LR1, LR2 = random_crop(HR0, HR1, HR2, LR0, LR1, LR2, self.patch_size, self.upscale_factor)
HR0 = HR0[:, :, np.newaxis]
HR1 = HR1[:, :, np.newaxis]
HR2 = HR2[:, :, np.newaxis]
LR0 = LR0[:, :, np.newaxis]
LR1 = LR1[:, :, np.newaxis]
LR2 = LR2[:, :, np.newaxis]
HR = np.concatenate((HR0, HR1, HR2), axis=2)
LR = np.concatenate((LR0, LR1, LR2), axis=2)
# data augmentation
LR, HR = augumentation()(LR, HR)
return toTensor(LR), toTensor(HR)
def __len__(self):
return self.n_iters
使用的方式如下:
train_set = TrainsetLoader(opt.trainset_dir_hr,opt.trainset_dir_lr, opt.upscale_factor, opt.patch_size, opt.n_iters)
train_loader = DataLoader(train_set, num_workers=4, batch_size=opt.batch_Size, shuffle=True)
for iteration,data in enumerate(train_loader,1):
print('..................................')
if cuda:
batch_lr,batch_hr = Variable(data[0]).cuda(gpus_list[1]),Variable(data[1]).cuda(gpus_list[1])
......