Pytorch数据载入函数介绍

Pytorch02——数据载入

参考:https://pytorch.org/docs/stable/data.html

本文主要是对Pytorch数据载入方式的官方文档翻译,以及梳理和总结。有错误的地方请诸位大佬指正!转载请注明来源!

主要涉及Python API中的torch.utils.data,Libraries库中的torchvision.datasets及torchvision.transforms。

目录

Pytorch02——数据载入

torch.utils.data.DataLoader

torchvision.datasets

torchvision.transforms

总结


torch.utils.data.DataLoader

torch.utils.data.DataLoader作为Pytorch数据载入工具的核心,代表了一种使用python迭代式载入数据的方式,并支持:

  1. 映射式和迭代式的数据集
  2. 自定义数据载入顺序
  3. 自动批处理
  4. 单线程或多线程的数据加载
  5. 自动内存锁页

这些选项可由DataLoader构造器的参数指定:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

下面详细描述这些参数的使用方法和效果。

  • Dataset Types(数据集类型)

dataset是DataLoader最重要的参数,制定了要加载的数据集对象类型。Pytorch支持两种不同类型的数据集:映射式数据集(map-style)、迭代式数据集(iterable-style)。

映射式数据集包含了__getitem__()和__len__()方法,使用了类似于字典格式的键值对映射关系,将索引(可能是非整型数据)映射到数据样本。比如通过dataset[idx]可以访问到第idx个图像以及其对应的标签。更多介绍在torch.utils.data.Dataset。

迭代式数据集是IterableDataset的一个实例,包含了__iter__()方法,类似于数据集的迭代器。适用于随机读取成本高甚至不可能的情况(内存很小),一个批处理的大小(batch size)取决于能够获取的数据。比如,以数据流的形式从远程服务器读取数据集,或者实时写入事件日志。更多介绍在torch.utils.data.IterableDataset。

注意,当使用IterableDataset进行多线程数据载入时,每个线程载入的数据集对象是相同的,会造成数据集重复,必须进行不同的配置。配置方法见IterableDataset文档。

  • Data Loading Order and Sampler(数据集载入顺序及采样)

对于迭代式的数据集,数据载入顺序完全由用户定义的迭代器控制,可以轻松实现块读取和动态设置批处理大小(使用yield语句控制一个batch size)。

对于映射式的数据集,torch.utils.data.Sampler可以用于指定数据集索引的顺序,是数据集索引的迭代器。比如,使用随机梯度下降(SGD),Sampler可以随机的置换数据集索引列表,并且使用yield语句每次输出一个索引,或者使用yield语句每次输出一小批索引用于实现min-batach SGD。

shuffle参数的值,可以控制顺序载入数据(False),还是随机载入数据(True)。此外,用户也可以使用sampler参数指明自定义的Sampler对象,该对象每次都会产生要提取的下一个索引。

自定义的Sampler每次产生的索引列表可以传递到batch_sampler参数,自动批处理也可以通过设置batch_size及drop_last开启。

注意:sampler和batch_sampler都和迭代式的数据集不兼容,因为迭代式的数据集没有索引。

  • Loading Batched and Non-Batched Data(载入批处理和非批处理的数据集)

通过batch_size,drop_last,batch_sampler三个参数的指定,DataLoader可以自动将独立获取的数据样本整理成批处理格式。

默认自动批处理,因为这是最常见的方式,一般批处理后向量的第一维是批处理序号。batch_size默认为1,而不是None,batch_size和drop_last参数可以用户指定如何获取一个批处理中的数据集索引。对于映射式的数据集,用户可以选择指定batch_sampler,每次产生一个数据集索引列表。

注意,batch_size,drop_last本质上是用于从sampler构建一个batch_sampler。对于映射式的数据集,sampler由用户构建或者基于shuffle参数构建。对于迭代式数据集,sampler是一个虚拟的无限数据集。当使用多线程载入迭代式的数据集时,drop_last参数可以丢弃每个线程最后一个非完整批次。

当使用sampler得到样本的索引列表后,collate_fn参数将索引对应的数据整理为一个批次。这种情况下,从映射式的数据集载入可以表示为:

for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

从迭代式的数据集载入可以表示为:

dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

自定义的collate_fn可以用于自定义数据集整理,比如批次不完整时,自动复制补全。

禁用自动批处理,在某些情况下,用户可能希望手动进行批处理,或仅加载单个样本。例如,直接加载批处理过的数据(例如,从数据集中批量读取或读取连续的内存块,或者批处理大小取决于数据,或者程序是为了处理单个样本设计的。在这些情况下,最好不要使用自动批处理(collate_fn用于整理样本),而应让数据加载器直接返回dataset对象的每个成员。

当batch_size和batch_sampler同时为None时,即可禁用自动批处理。每个样本将会经过collate_fn处理。当禁用自动批处理时,默认的collate_fn将Numpy数组转换为Tensor,其他保持不变。这种情况下,从映射式的数据集载入可以表示为:

for index in sampler:
    yield collate_fn(dataset[index])

从迭代式的数据集载入可以表示为:

for data in iter(dataset):
    yield collate_fn(data)

 使用collate_fn,collate_fn在自动批处理是否禁用的不同状态下,使用方法略有不同。

禁用自动批处理时,在处理每个单独的数据样本时会调用collate_fn,输出由数据加载迭代器产生,默认的collate_fn将numpy转换为tensor。

自动批处理时,在处理一批数据样本时会调用collate_fn,会将输入的样本整理为一个batch然后从数据加载迭代器输出。比如,每个样本包含了一个3通道的图片和它对应的整型标签,数据集的每个样本为元组(image, class_index),默认的collate_fn把这样元组构成的列表,整理为一个图像元组构成的Tensor和一个标签构成的Tensor。默认的collate_fn有以下特性:

  1. 总是在tensor前添加一个维度,表示batch的索引;
  2. 自动将numpy数组转化为tensor;
  3. 保留数据结构,比如,样本是dictionary格式,则输出为相同的key构成的tensor(如果不能转为tensor,则转为list)。
  • Single- and Multi-process Data Loading(单进程和多进程的数据加载)

默认使用单进程加载,通过num_workers参数可以指定加载的进程数量。python的全局解释器锁(GIL)不允许真正地完全并行化python代码。

单进程数据加载,数据加载会阻止计算,当用于进程间共享数据资源有限时,或数据集很小可以完全载入到内存中,建议使用单进程加载。单进程加载会显示更多的错误信息,有利于调试。

多进程数据加载,将num_workers设置为正整数,可以开启多进程加载。此时,会创建一个DataLoader的迭代器(比如,当你调用enumerate(dataloader)),以及指定进程数量的进程。dataset、collate_fn、worker_init_fn被传递给每个进程,用于初始化和获取数据。数据集将同时访问它内部的IO,每个进程都会执行transform。

torch.utils.data.get_worker_info()返回进程的信息(包括进程ID,数据集,初始化种子等等),在主进程中返回None。用户可以使用torch.utils.data.get_worker_info()或worker_init_fn单独配置每个数据集复制品,并且决定代码是否在每个进程运行。

对于映射式的数据集,主进程使用sampler产生索引,将索引传递给进程。也就是,所有的随机处理在主进程中完成。

对于迭代式的数据集,由于每个进程得到的是数据集的复制品,不进行任何设置会导致数据重复。使用torch.utils.data.get_worker_info()和worker_init_fn,用户可以分别配置每个进程。同样的,多进程加载,drop_last参数可以丢弃每个进程最后一个不完整batch。

由于workers依赖于python的multiprocessing库,所以在windows和unix系统的表现不同。

  • Memory Pinning(内存锁页)

当使用锁页内存向GPU拷贝数据时,速度更快。对于数据加载,将pin_memory设置为True,会自动将数据的tensor放在锁页内存中,能更快的将数据传输至GPU。

默认的锁页内存,逻辑上只能识别tensor以及tensor的映射和迭代。如果使用自定义的batch类型,锁页内存将无法识别,同时返回没有锁页内存的batch。如果想要使用锁页内存加载自定义的batch类型,需要重定义pin_memory()方法。

class SimpleCustomBatch:
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    # custom memory pinning method on custom type
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    return SimpleCustomBatch(batch)

inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)

for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())

torch.utils.data.Dataset

torch.utils.data.IterableDataset

torch.utils.data.TensorDataset

torch.utils.data.ConcatDataset

torch.utils.data.ChainDataset

torch.utils.data.Subset

torch.utils.data.Sampler

torch.utils.data.SequentialSampler

torch.utils.data.RandomSampler

torch.utils.data.SubsetRandomSampler

torch.utils.data.WeightedRandomSampler

torch.utils.data.BatchSampler

torch.utils.data.distributed.DistributedSampler

torchvision.datasets

所有的torchvision.datasets都是torch.utils.data.Dataset的子类,都有__getitem__和__len__方法。因此,他们都可以传递给torch.utils.data.DataLoader。

自定义数据集主要需要以下两个类DatasetFolder和ImageFolder:

DatasetFolder主要用于通用数据的加载,比如文本、图像等数据。

ImageFolder主要用于通用图像的加载,比如jpg、png等不同格式的图片。

torchvision.datasets.DatasetFolder(root: str, loader: Callable[str, Any], extensions: Optional[Tuple[str, ...]] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, is_valid_file: Optional[Callable[str, bool]] = None)

Parameters:
root (string) – Root directory path.

loader (callable) – A function to load a sample given its path.

extensions (tuple[string]) – A list of allowed extensions. both extensions and is_valid_file should not be passed.

transform (callable, optional) – A function/transform that takes in a sample and returns a transformed version. E.g, transforms.RandomCrop for images.

target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

is_valid_file – A function that takes path of a file and check if the file is a valid file (used to check of corrupt files) both extensions and is_valid_file should not be passed.

Introduction:
A generic data loader where the samples are arranged in this way:

root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/xxz.ext

root/class_y/123.ext
root/class_y/nsdf3.ext
root/class_y/asd932_.ext
torchvision.datasets.ImageFolder(root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, loader: Callable[str, Any] = , is_valid_file: Optional[Callable[str, bool]] = None)

Parameters:
root (string) – Root directory path.

transform (callable, optional) – A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop

target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

loader (callable, optional) – A function to load an image given its path.

is_valid_file – A function that takes path of an Image file and check if the file is a valid file (used to check of corrupt files)

Introduction:
A generic data loader where the images are arranged in this way:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

torchvision.transforms

torchvision.transforms是公共的图片转换类,可以通过Compose连接在一起,此外,torchvision.transforms.functional模块可以提供更多的转换控制。所有的tranformation可以接受PIL图片,Tensor图片,或者以batch为单位的Tensor图片。Tensor图片格式为(C, H, W)(C for Channerl, H for height, W for width),以batch为单位的Tensor图片格式为(B, C, H, W)(B for batch)。确定性的或者随机的transformation会相同的应用在batch中的每个图片。

警告:因为v0.8.0所有的随机transformation使用torch默认的随机生成器产生随机参数,且这是向后兼容的,用户应该使用如下代码设置随机状态:

# Previous versions
# import random
# random.seed(12)

# Now
import torch
torch.manual_seed(17)

注意,即使使用相同的随机种子,torch的随机生成器和python的随机生成器也会产生不同的结果。

  • scriptable transforms

为了脚本化转化,请使用torchnn.Sequential而不是Compose。同时,请确保只是用脚本化的转化,比如:不要求lambda函数的torch.Tensor对象或者PIL.Image。任何自定义的转化,可以使用torch.jit.script,且应该派生自torch.nn.Module。

transforms = torch.nn.Sequential(
    transforms.CenterCrop(10),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
)
scripted_transforms = torch.jit.script(transforms)

 

  • compositions of transforms

将几个transform组合在一起,不支持脚本化。

torchvision.transforms.Compose(transforms)

# example
transforms.Compose([
     transforms.CenterCrop(10),
     transforms.ToTensor(),
 ])

如果要进行脚本化,请使用torch.nn.Sequential。

>>> transforms = torch.nn.Sequential(
>>>     transforms.CenterCrop(10),
>>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
>>> )
>>> scripted_transforms = torch.jit.script(transforms)
  • transforms on PIL Image and torch.*Tensor
# 中心剪裁
torchvision.transforms.CenterCrop(size)

# 将图片裁剪为四个角和一个中心
torchvision.transforms.FiveCrop(size)

# 将图片及其翻转图片剪裁为4个角+1个中心
torchvision.transforms.TenCrop(size, vertical_flip=False)

# 随机剪裁
torchvision.transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant')

# 随机变换尺寸并裁剪
torchvision.transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333), interpolation=2)

# 随机缩放
torchvision.transforms.Resize(size, interpolation=2)

# 随机改变亮度,对比度和饱和度
torchvision.transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)

# 随机旋转
torchvision.transforms.RandomRotation(degrees, resample=False, expand=False, center=None, fill=None)

# 高斯模糊
torchvision.transforms.GaussianBlur(kernel_size, sigma=(0.1, 2.0))

# 灰度化
torchvision.transforms.Grayscale(num_output_channels=1)

# 随机灰度
torchvision.transforms.RandomGrayscale(p=0.1)

# 每条边填补至指定的像素点数目
torchvision.transforms.Pad(padding, fill=0, padding_mode='constant')

# 随机仿射变换
torchvision.transforms.RandomAffine(degrees, translate=None, scale=None, shear=None, 

# 随机透视变换
torchvision.transforms.RandomPerspective(distortion_scale=0.5, p=0.5, interpolation=2, fill=0)

# 随机水平翻转
torchvision.transforms.RandomHorizontalFlip(p=0.5)

# 随机竖直翻转
torchvision.transforms.RandomVerticalFlip(p=0.5)

# 随机选择变换
torchvision.transforms.RandomApply(transforms, p=0.5)
  • transforms on PIL Image only
# 随机选择一种变换方式,不支持torchscript
torchvision.transforms.RandomChoice(transforms)

# 随机选择变换方式的顺序,不支持torchscript
torchvision.transforms.RandomOrder(transforms)
  • transforms on torch.*Tensor only
# 
torchvision.transforms.LinearTransformation(transformation_matrix, mean_vector)

# 使用给定的均值和标准差进行归一化
torchvision.transforms.Normalize(mean, std, inplace=False)

# 随机擦除矩形区域
torchvision.transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False)

# 转变图像数据类型
torchvision.transforms.ConvertImageDtype(dtype: torch.dtype)
  • conversion transform
# tensor转为PIL图像
torchvision.transforms.ToPILImage(mode=None)

# PIL图像转为tensor
torchvision.transforms.ToTensor
  • generic transform
# 根据用户定义的函数进行变换,不支持torchscript
torchvision.transforms.Lambda(lambd)
  • functional transform
# 使用functional transforms可以更精准的控制变换,但是不会随机生成参数,需要自行明确定义
import torchvision.transforms.functional as TF

# 调整亮度
torchvision.transforms.functional.adjust_brightness(img: torch.Tensor, brightness_factor: float) → torch.Tensor

# 调整对比度
torchvision.transforms.functional.adjust_contrast(img: torch.Tensor, contrast_factor: float) → torch.Tensor

# 调整gamma值
torchvision.transforms.functional.adjust_gamma(img: torch.Tensor, gamma: float, gain: float = 1) → torch.Tensor

# 调整色度
torchvision.transforms.functional.adjust_hue(img: torch.Tensor, hue_factor: float) → torch.Tensor

# 调整饱和度
torchvision.transforms.functional.adjust_saturation(img: torch.Tensor, saturation_factor: float) → torch.Tensor

# 设置剪裁
torchvision.transforms.functional.crop(img: torch.Tensor, top: int, left: int, height: int, width: int) → torch.Tensor

# 设置中心剪裁
torchvision.transforms.functional.center_crop(img: torch.Tensor, output_size: List[int]) → torch.Tensor

# 设置five剪裁
torchvision.transforms.functional.five_crop(img: torch.Tensor, size: List[int]) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

# 设置ten剪裁
torchvision.transforms.functional.ten_crop(img: torch.Tensor, size: List[int], vertical_flip: bool = False) → List[torch.Tensor]

# 设置缩放剪裁
torchvision.transforms.functional.resized_crop(img: torch.Tensor, top: int, left: int, height: int, width: int, size: List[int], interpolation: int = 2) → torch.Tensor

# 设置缩放
torchvision.transforms.functional.resize(img: torch.Tensor, size: List[int], interpolation: int = 2) → torch.Tensor


# 设置擦除
torchvision.transforms.functional.erase(img: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False) → torch.Tensor


# 设置高斯模糊
torchvision.transforms.functional.gaussian_blur(img: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) → torch.Tensor

# 设置水平翻转
torchvision.transforms.functional.hflip(img: torch.Tensor) → torch.Tensor

# 设置竖直翻转
torchvision.transforms.functional.vflip(img: torch.Tensor) → torch.Tensor

# 设置归一化
torchvision.transforms.functional.normalize(tensor: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) → torch.Tensor

# 设置补全
torchvision.transforms.functional.pad(img: torch.Tensor, padding: List[int], fill: int = 0, padding_mode: str = 'constant') → torch.Tensor

# 设置透射变换
torchvision.transforms.functional.perspective(img: torch.Tensor, startpoints: List[List[int]], endpoints: List[List[int]], interpolation: int = 2, fill: Optional[int] = None) → torch.Tensor

# 设置仿射变换
torchvision.transforms.functional.affine(img: torch.Tensor, angle: float, translate: List[int], scale: float, shear: List[float], resample: int = 0, fillcolor: Optional[int] = None) → torch.Tensor

# 设置旋转
torchvision.transforms.functional.rotate(img: torch.Tensor, angle: float, resample: int = 0, expand: bool = False, center: Optional[List[int]] = None, fill: Optional[int] = None) → torch.Tensor

# RGB转灰度
torchvision.transforms.functional.rgb_to_grayscale(img: torch.Tensor, num_output_channels: int = 1) → torch.Tensor

# 设置转灰度
torchvision.transforms.functional.to_grayscale(img, num_output_channels=1)

# 设置数据转换
torchvision.transforms.functional.convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32) → torch.Tensor

# PIL转为tensor
torchvision.transforms.functional.pil_to_tensor(pic)

# tensor转PIL
torchvision.transforms.functional.to_pil_image(pic, mode=None)

# PIL转tensor
torchvision.transforms.functional.to_tensor(pic).

 

总结

仿射变换Affine和透视变换Perspective有什么区别?

仿射变换:参考:https://www.cnblogs.com/happystudyeveryday/p/10547316.html

Affine Transformation是一种二维坐标到二维坐标之间的线性变换,使用一个矩阵表示图像的平移、缩放、旋转、切变、对称变换及它们的线性组合。

\begin{pmatrix} x\\ y \end{pmatrix} = \begin{pmatrix} a_{11} & a_{12}& a_{13}\\ a_{21} & a_{22}& a_{23} \end{pmatrix} \begin{pmatrix} u\\ v\\ 1 \end{pmatrix}

Pytorch数据载入函数介绍_第1张图片

透视变换:参考:https://zhuanlan.zhihu.com/p/36082864

Perspeactive是三维空间的变换。

\begin{pmatrix} x\\ y\\z \end{pmatrix} = \begin{pmatrix} a_{11} & a_{12}& a_{13}\\ a_{21} & a_{22}& a_{23}\\a_{31} & a_{32}& a_{33} \end{pmatrix} \begin{pmatrix} u\\ v\\ 1 \end{pmatrix}

从另一个角度也能说明三维变换和二维变换的意思,仿射变换的方程组有6个未知数,所以要求解就需要找到3组映射点,三个点刚好确定一个平面。透视变换的方程组有8个未知数,所以要求解就需要找到4组映射点,四个点就刚好确定了一个三维空间。

 

线性变换LinearTransformation和归一化Normalization有什么区别?

参考:https://zhuanlan.zhihu.com/p/33173246

白化(whitening)”是一个重要的数据预处理步骤。白化一般包含两个目的:

(1)去除特征之间的相关性 —> 独立;

(2)使得所有特征具有相同的均值和方差 —> 同分布。

LinearTransformation官方文档举例可以用作白化,我还没完全搞懂,暂时留白,日后搞清楚了再补上。也欢迎各位大佬在评论区补充。

归一化是为了将不同特征维度的数据强行变为均值为0,方差为1的分布,这样在梯度更新时,更容易收敛。各种归一化的方式在参考的知乎文章中讲的很好,就不再赘述了。

 

除了我们经常使用的归一化处理,当我们手头掌握的训练数据很少时,可以选择使用transform中的平移、旋转、裁剪、缩放等等简单的方法,增加数据量,或者使用GAN等生成式的网络产生新的数据。

常用的数据集加载语法:

# 指明转换方式,一般常用的是归一化处理
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((mean for each channel), (standard_deviation for each channel), inplace=False)
    ])
# 将数据根据transform进行处理
dataset = torchvision.datasets.ImageNet('root_directory', train/val, download=True, transform)
# 将数据划分为batch,并载入内存
dataloader = torch.utils.data.DataLoader(dataset, batch_size=,shuffle=,num_workers=,drop_last=)

 

你可能感兴趣的:(Pytorch,人工智能,pytorch)