教科书 Pytorch入门与实践第五章:pytorch-book/chapter5.ipynb at master · chenyuntc/pytorch-book · GitHub
1、TORCH.UTILS.DATA 官网地址:torch.utils.data — PyTorch 1.11.0 documentation
在PyTorch中,数据加载可通过自定义的数据集对象。数据集对象被抽象为Dataset
类,实现自定义的数据集需要继承Dataset,并实现两个Python魔法方法:
__getitem__
:返回一条数据,或一个样本。obj[index]
等价于obj.__getitem__(index)
__len__
:返回样本的数量。len(obj)
等价于obj.__len__()
其他的数据集类必须是torch.utils.data.Dataset的子类,比如说torchvision.ImageFolder.
创建Dataset例子:
2. class torch.utils.data.sampler.Sampler(data_source)
参数: data_source (Dataset) – dataset to sample from
作用: 创建一个采样器, class torch.utils.data.sampler.Sampler是所有的Sampler的基类, 其中,iter(self)函数来获取一个迭代器,对数据集中元素的索引进行迭代,len(self)方法返回迭代器中包含元素的长度。
3. class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
参数:
* dataset (Dataset): 加载数据的数据集
* batch_size (int, optional): 每批加载多少个样本
* shuffle (bool, optional): 设置为“真”时,在每个epoch对数据打乱.(默认:False)
* sampler (Sampler, optional): 定义从数据集中提取样本的策略,返回一个样本
* batch_sampler (Sampler, optional): like sampler, but returns a batch of indices at a time 返回一批样本. 与atch_size, shuffle, sampler和 drop_last互斥.
* num_workers (int, optional): 用于加载数据的子进程数。0表示数据将在主进程中加载。(默认:0)
* collate_fn (callable, optional): 合并样本列表以形成一个 mini-batch. # callable可调用对象
* pin_memory (bool, optional): 如果为 True, 数据加载器会将张量复制到 CUDA 固定内存中,然后再返回它们.pin memory中的数据转到GPU会快一些
* drop_last (bool, optional): 设定为 True 如果数据集大小不能被批量大小整除的时候, 将丢掉最后一个不完整的batch,(默认:False).
* timeout (numeric, optional): 如果为正值,则为从工作人员收集批次的超时值。应始终是非负的。(默认:0)
* worker_init_fn (callable, optional): If not None, this will be called on each worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading. (default: None).
* generator (torch.Generator, optional) – If not None
, this RNG will be used by RandomSampler to generate random indexes and multiprocessing to generate base_seed for workers. (default: None
)如果此选项不是None,RandomSampler 将使用RNG去生成随机的索引,在多进程中也会生成每个进程基准种子)
* prefetch_factor (int, optional, keyword-only arg) – Number of samples loaded in advance by each worker. 2
means there will be a total of 2 * num_workers samples prefetched across all workers. (default: 2
)对每个进程来说,需要提前装载的样本数量。当此值为2时,对所有的进程来数,需要提前获取的样本数量为2*num_workers,默认值为2)
* persistent_workers (bool, optional) – If True
, the data loader will not shutdown the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. (default: False
)如果设置为True,对每个进程来数,当一个数据集被使用一次后,此进程并不会被关闭,这样就会保持进程中的数据集实例是活的。默认值是False。
代码例子:
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=3, shuffle=True,
num_workers=0, drop_last=False)
dataiter = iter(dataloader)
imgs, labels = next(dataiter)
imgs.size() # batch_size, channel, height, weight
输出:torch.Size([3, 3, 224, 224])
4、Dataset Types:DataLoader中最重要的参数就是dataset,它决定了要装载的数据集。torch支持两种类型的数据集。
(1)map-style 类型。一个map-style类型是实现了__getitem__() 和__len__()协议的类,它代表了一个从索引/键值 到数据样本的映射。例如,对于一个通过dataset[idx]访问的数据集,可以读到第idx个图片,并从磁盘的文件中取到对应的标签。看下面例子吧
import os
from PIL import Image
import numpy as np
from torchvision import transforms as T
transform = T.Compose([
T.Resize(224), # 缩放图片(Image),保持长宽比不变,最短边为224像素
T.CenterCrop(224), # 从图片中间切出224*224的图片
T.ToTensor(), # 将图片(Image)转成Tensor,归一化至[0, 1]
T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化至[-1, 1],规定均值和标准差
])
class DogCat(data.Dataset):
def __init__(self, root, transforms=None):
imgs = os.listdir(root)
self.imgs = [os.path.join(root, img) for img in imgs]
self.transforms=transforms
def __getitem__(self, index):
img_path = self.imgs[index]
label = 0 if 'dog' in img_path.split('/')[-1] else 1
data = Image.open(img_path)
if self.transforms:
data = self.transforms(data)
return data, label
def __len__(self):
return len(self.imgs)
dataset = DogCat('./data/dogcat/', transforms=transform)
img, label = dataset[0]
for img, label in dataset:
print(img.size(), label)
输出:
torch.Size([3, 224, 224]) 0 torch.Size([3, 224, 224]) 1 torch.Size([3, 224, 224]) 1 torch.Size([3, 224, 224]) 0 torch.Size([3, 224, 224]) 1 torch.Size([3, 224, 224]) 0 torch.Size([3, 224, 224]) 0 torch.Size([3, 224, 224]) 1
(2)Iterable-style 类型。这种类型的父类是IterableDataset,并且实现了 __iter__() 协议,代表了在数据样本的迭代器。这种类型非常适合随机读取非常难或者不可能的情况,这种情况下批的大小取决于得到的数据。例如有这样一个数据集,可以调用iter(dataset),可以从数据库、远程服务器或者实时产生的日志获取样本流。
注意:当使用多进程加载Iterable-style 类型的DataLoader时,一份样本会被复制至所有的进程中,这份复制的样本将被差异配置以避免重复数据。请参阅IIterableDataset文档以如何实现它。
dataloader是一个可迭代的对象,意味着我们可以像使用迭代器一样使用它,例如:
for batch_datas, batch_labels in dataloader:
train()
或
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False)
dataiter = iter(dataloader)
batch_datas, batch_labesl = next(dataiter)
5、PyTorch中还单独提供了一个sampler
模块,用来对数据进行采样。常用的有随机采样器:RandomSampler
,当dataloader的shuffle
参数为True时,系统会自动调用这个采样器,实现打乱数据。默认的是采用SequentialSampler
,它会按顺序一个一个进行采样。这里介绍另外一个很有用的采样方法: WeightedRandomSampler
,它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它来进行重采样。
语法:CLASStorch.utils.data.
Sampler
(data_source)
功能:所有sampler的基类。所有子类都必须重写__iter__()
方法,提供一种在dataset elements的indices/keys上的迭代方法。选择性重写__len__()
方法,返回这个迭代器的长度。
参数:
Dataset
):dataset to sample from.语法:CLASStorch.utils.data.
SequentialSampler
(data_source)
功能:顺序采样。
语法:CLASS torch.utils.data.
RandomSampler
(data_source, replacement=False, num_samples=None, generator=None)
功能:随机采样。
语法:CLASS torch.utils.data.
WeightedRandomSampler
(weights, num_samples, replacement=True, generator=None)
功能:权重采样。构建WeightedRandomSampler
时需提供两个参数:每个样本的权重weights
、共选取的样本总数num_samples
,以及一个可选参数replacement
。权重越大的样本被选中的概率越大,待选取的样本数目一般小于全部的样本数目。replacement
用于指定是否可以重复选取某一个样本,默认为True,即允许在一个epoch中重复采样某一个数据。如果设为False,则当某一类的样本被全部选取完,但其样本数目仍未达到num_samples时,sampler将不会再从该类中选择数据,此时可能导致weights
参数失效。
代码:
dataset = DogCat('data/dogcat/', transforms=transform)
# 狗的图片被取出的概率是猫的概率的两倍
# 两类图片被取出的概率与weights的绝对大小无关,只和比值有关
weights = [2 if label == 1 else 1 for data, label in dataset]
weights
from torch.utils.data.sampler import WeightedRandomSampler
sampler = WeightedRandomSampler(weights,\
num_samples=9,\
replacement=True)
dataloader = DataLoader(dataset,
batch_size=3,
sampler=sampler)
for datas, labels in dataloader:
print(labels.tolist())
输出:
[1, 2, 2, 1, 2, 1, 1, 2]
[1, 0, 1] [0, 0, 1] [1, 0, 1]