#! https://zhuanlan.zhihu.com/p/543481537
本次笔记记录如何构建数据集,如何搭建minibatch,以及DataLoader源码剖析.
#使用Torchvision导入内置数据集
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
继承抽象类Dataset,并自己实现以下三个函数:
__init__
__len__
__getitem__
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])#每张图片路径
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)#预处理
if self.target_transform:
label = self.target_transform(label)#预处理label
return image, label
这种dataset叫做map-style datasets
读流式数据可以用iterable-style dataset
使用Minibatch的形式训练数据,在每个epoch都reshuffle数据从而降低模型过拟合的可能性,还通过python的multiprocessing多进程加载数据,使得读数据的过程不影响gpu训练,实现低延迟。(有的数据集读取的过程贼慢,比方说TextVQA这种,很影响GPU训练的效率)
DataLoader返回的是一个列表。
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
#test不需要梯度下降,只要前向传递,不需要shuffle
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
展示一张图片与标签,
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()#灰度图片要变成三通道
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
官网的教程过于简单,下面深入剖析DataLoader,
包括sampler,collate_fn处理等,还有一些排序方法(比如让一个minibatch中样本长度差不多,这样在进行padding等操作的时候就会简单很多)。
主要解析三部分代码:
首先看dataloader
'''
dataset: Dataset[T_co]:传入一个实例化的dataset对象
shuffle:每个epoch后打乱数据,一般在训练集上用
sampler/batch_sampler:自定义采样方式,显然与shuffle冲突
num_workers:进程数量,多进程读数据,0代表只用一个主进程
pin_memory:把tensor保存在gpu上,不用重复保存,对于效率影响有待考究
drop_last:样本数量不是batch_size整数倍时,把最后一部分数据丢掉
collate_fn:自定义聚集函数,对小批次数据进行再次处理,输入一个Batch,输出一个batch
timeout:是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。
'''
def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
batch_sampler: Optional[Sampler[Sequence[int]]] = None,
num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
multiprocessing_context=None, generator=None,
*, prefetch_factor: int = 2,
persistent_workers: bool = False):
torch._C._log_api_usage_once("python.data_loader")
if num_workers < 0:
raise ValueError('num_workers option should be non-negative; '
'use num_workers=0 to disable multiprocessing.')
if timeout < 0:
raise ValueError('timeout option should be non-negative')
if num_workers == 0 and prefetch_factor != 2:
raise ValueError('prefetch_factor option could only be specified in multiprocessing.'
'let num_workers > 0 to enable multiprocessing.')
assert prefetch_factor > 0
if persistent_workers and num_workers == 0:
raise ValueError('persistent_workers option needs num_workers > 0')
#设置成员变量
self.dataset = dataset
self.num_workers = num_workers
self.prefetch_factor = prefetch_factor
self.pin_memory = pin_memory
self.timeout = timeout
self.worker_init_fn = worker_init_fn
self.multiprocessing_context = multiprocessing_context
#后续代码太长,这里不一一展示,后面不算很难
后续代码太长,这里不一一展示,init函数主要做三件事:
这里说几点重要的,方便放矢地看源码:
1.构建sampler
1.1.shuffle内部通过RandomSampler的具体实现,重点看__iter__.
def __iter__(self) -> Iterator[int]:
n = len(self.data_source)
if self.generator is None:
generator = torch.Generator()
generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
else:
generator = self.generator
if self.replacement:
for _ in range(self.num_samples // 32):
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
else:
yield from torch.randperm(n, generator=generator).tolist()
1.2.shuffle = None通过SequentialSampler实现.
class SequentialSampler(Sampler[int]):
r"""Samples elements sequentially, always in the same order.
Args:
data_source (Dataset): dataset to sample from
"""
data_source: Sized
def __init__(self, data_source: Sized) -> None:
self.data_source = data_source
def __iter__(self) -> Iterator[int]:
return iter(range(len(self.data_source)))
def __len__(self) -> int:
return len(self.data_source)
2.我们不使用batchsampler而仅仅是设定batch_size,内部其实还是通过BatchSampler实现的.
if batch_size is not None and batch_sampler is None:
# auto_collation without custom batch_sampler
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
这里插入说一下python的迭代器和生成器,方便下面的理解
'''
迭代是Python最强大的功能之一,是访问集合元素的一种方式。
迭代器是一个可以记住遍历的位置的对象。
迭代器对象从集合的第一个元素开始访问,直到所有的元素被访问完结束。迭代器只能往前不会后退。
迭代器有两个基本的方法:iter() 和 next()。
字符串,列表或元组对象都可用于创建迭代器:
'''
>>> list=[1,2,3,4]
>>> it = iter(list) # 创建迭代器对象
>>> print (next(it)) # 输出迭代器的下一个元素
1
>>> print (next(it))
2
#迭代器对象可以使用常规for语句进行遍历:
list=[1,2,3,4]
it = iter(list) # 创建迭代器对象
for x in it:
print (x, end=" ")
'''
在 Python 中,使用了 yield 的函数被称为生成器(generator)。
跟普通函数不同的是,生成器是一个返回迭代器的函数,只能用于迭代操作,更简单点理解生成器就是一个迭代器。
在调用生成器运行的过程中,每次遇到 yield 时函数会暂停并保存当前所有的运行信息,返回 yield 的值, 并在下一次执行 next() 方法时从当前位置继续运行。
调用一个生成器函数,返回的是一个迭代器对象。
以下实例使用 yield 实现斐波那契数列:
'''
import sys
def fibonacci(n): # 生成器函数 - 斐波那契
a, b, counter = 0, 1, 0
while True:
if (counter > n):
return
yield a
a, b = b, a + b
counter += 1
f = fibonacci(10) # f 是一个迭代器,由生成器返回生成
while True:
try:
print (next(f), end=" ")
except StopIteration:
sys.exit()
回归BatchSampler实现:
def __iter__(self) -> Iterator[List[int]]:
batch = []
for idx in self.sampler:
batch.append(idx)
f len(batch) == self.batch_size:
yield batch #生成器,每次next返回一个batch
batch = []
if len(batch) > 0 and not self.drop_last:#drop_last的原理
yield batch
3.collate_fn自定义聚集函数
if collate_fn is None:
if self._auto_collation: #batch_sample is not None ->true
collate_fn = _utils.collate.default_collate #输入输出都是batch,做了一些数据类型的转换
else:
collate_fn = _utils.collate.default_convert
self.collate_fn = collate_fn
4.迭代器iter()可以把dataloader变成一个迭代器
iter
-> _get_iterator()
_get_iterator()
-> _SingleProcessDataLoaderIter()
or _MultiProcessingDataLoaderIter()
_SingleProcessDataLoaderIter()
继承自BaseDataLoaderIter
.
我们发现_SingleProcessDataLoaderIter()
的_next_data
函数没有在子类中被调用,猜测可以在父类BaseDataLoaderIter
找到调用.
经查看,父类__next__
中确实调用了_next_data
,这就可以解释了代码的内部原理。