在PyTorch官方文档中提供了torchvision.transforms模块对图片数据进行变换,torch.utils.data.Dataset 和 torch.utils.data.DataLoader模块来读取数据。要实现自定义数据集,就要继承 torch.utils.data.Dataset
,并实现__getitem__()
和 __len__()
两个方法用于读取并处理数据,得到相对应的数据处理结果后。将自定义的Dataset
封装到 DataLoader
中,就能实现了单/多进程迭代输出数据。
在训练过程中,数据的处理基本包括如下:
torchvision.transforms
模块中的方法进行简单处理。当然也可以自己写方法去处理。MyDataSet
类,继承torch.utils.data.Dataset
类,并实现__getitem__()
和 __len__()
方法。DataLoader
数据加载器根据自定义数据集加载数据。其中可以使用默认的Sampler
和Collate Function
。可视化可参考:PyTorch DataLoader工作原理可视化
在torchvision.transforms模块中提供了一般的图像数据变换操作类,可以用于实现数据预处理(data preprocessing)和数据增广(data argumentation)。这里列举一些常用的变换操作。
这可以看作是一种容器,能够将多种数据变换进行组合。输入是对载入数据的各种变换操作集列表。
transformer = transforms.Compose([
transforms.Resize(224,224),
transforms.transforms.RandomResizedCrop((224), scale = (0.5,1.0)),
transforms.RandomHorizontalFlip(),
])
# 对图片img进行变换操作
img_trans = transformer(img)
标准正态分布对数据进行标准化,其中mean是均值,std是标准差,变换完成后数据符合均值为0,标准差为1的标准正态分布。对于RGB三通道图,mean和std可以是三维的。
normalize_transformer = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
对载入的图片数据进行缩放,其中,size可以是整数类型(将长宽中最短的缩放到size,然后长的等比缩放),也可以是(h,w)
的序列。
中心裁剪,以输入图的中心点为中心点为参考点,按需要的大小进行裁剪。如果输入的是一个整型数据,那么裁剪的长和宽都是这个数值。当然也可以是(h,w)
类型
当然对于图片的裁剪操作布置这些,还有:
- transforms.RandomCrop(size):随机裁剪
- transforms.RandomResizedCrop(size,scale):先将给定图像随机裁剪为不同的大小和宽高比,然后缩放所裁剪得到的图像为size的大小。
- transforms.RandomHorizontalFlip:用于对载入的图片按随机概率进行水平翻转
- transforms.RandomVerticalFlip:用于对载入的图片按随机概率进行垂直翻转。
- transforms.RandomRotation:按照degree随机旋转一定角度
用于对载入的图片数据进行类型转换,将之前构成PIL图片的数据转换成Tensor数据类型,让PyTorch能够对其进行计算和处理。
用于将Tensor变量的数据转换成PIL图片数据,主要是为了方便图片内容的显示。
在官方介绍中,如果要使用自定义数据集,需要继承torch.utils.data
类,并实现__getitem__()
和 __len__()
两个方法:
__len__
返回的是数据集的大小
__getitem__
实现通过索引获取数据集中的某一个数据,以[input,label]
的形式给出。
在包torch.utils.data
中,包含pytorch
内部默认的数据处理类:
Dataset(object)
IterableDataset(Dataset)
TensorDataset(Dataset)
: 封装成tensor的数据集,每一个样本都通过索引张量来获得。ConcatDataset(Dataset)
: 连接不同的数据集以构成更大的新数据集Subset(Dataset)
: 获取指定一个索引序列对应的子数据集ChainDataset(IterableDataset)
而在torchvision
也封装了几种常见的数据集,在torchvision.datasets
中,包括:FashionMNIST
, ImageFolder
, CIFAR10
, CIFAR100
, SVHN
, PhotoTour
, ImageNet
, CocoDetection
等。
这里对torch
中Dataset类
,TensorDataset类
,torchvision
中ImageFolder类
,FashionMNIST类
进行分析,其继承关系如下图所示:
这里有个问题,就是在看源码的时候
torchvision.datasets.Dataset
中没有__len__()
方法,而是在后面类中定义的这个方法。但是官网中说是自定义数据集需要实现两种方法__getitem__()
和__len__()
方法。如果是直接继承torch.utils.data.Dataset
类的,比如TensorDataset
数据集继承的方法中是不包括__len__()
的。
可以看出,所有的实现类基本都是直接或者间接继承于torch.utils.data.Dataset
这个类的。基于此,编写自定义数据集类:
创建my_dataset文件,内容如下:
import torch
from torch.utils.data import Dataset
import numpy as np
# 自定义数据集,继承torch.utils.data.Dataset
class MyDataSet(Dataset):
# 初始化函数,得到数据,这里不绝对,
def __init__(self, pathData, pathLabel):
self.data = np.load(pathData) # 传入dataset 特征的路径
self.label = np.load(pathLabel) # 传入dataset 中label的路径
# 该函数返回数据大小长度,目的是方便DataLoader划分。
def __len__(self):
return len(self.data)
# index是根据batchsize划分数据后得到的索引,最后将data和对应的labels一起返回
def __getitem__(self, index):
data = self.data[index]
labels = self.label[index]
return data, labels
# 表示静态方法,该方法不一定需要,只是在dataloader中方便使用而已
@staticmethod
def collate_fn(batch):
"""
该方法用于DataLoader中的collate_fn参数。到时候可以直接使用 对象.collate_fn,或者 类.collate_fn。
该方法是在Dataloader中重新整理数据的方法。对该batch中的数据进行重新整理。如果没有定义,则会使用默认的collate_fn
:param batch:
:return:
"""
# 官方实现的default_collate可以参考
# https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
images, labels = tuple(zip(*batch))
images = torch.stack(images, dim=0)
labels = torch.as_tensor(labels)
return images, labels
那么到时候实例化自定义数据集的时候可以通过:
pathX = '' # 特征数据的文件地址
pathY = '' # 标签数据的文件地址
torch_data = MyDataSet(pathX,pathY)# 实例化数据集dataset
torch.utils.data.DataLoader(dataset, batch_size, shuffle, drop_lase, num_workers, collate_fn, sampler)
参数含义:
torch.utils.data.Dataset
对象数据或者子类对象。通过Loader操作得到的数据可以通过迭代器进行输出数据。如下:
datas = DataLoader(torch_dataset, batch_size, shuffle=True, num_workers=0)
for i, data in enumerate(datas):
# 这里的i表示第几个batch的数据,而data表示该batch对应的数据,包含训练数据和标签
print("{}个batch \n {}".format(i, data))
# 通过这种方式获取第i个batch数据中训练数据和训练样本
images, labels = data
pytorch采样器有如下几个(
torch.utils.data
包中):
Sampler
SequentialSampler
: 顺序采样样本,始终按照同一个顺序。RandomSampler
: 无放回地随机采样样本元素。SubsetRandomSampler
: 无放回地按照给定的索引列表采样样本元素WeightedRandomSampler
: # 按照给定的概率来采样样本。BatchSampler
: # 在一个batch中封装一个其他的采样器。DistributedSampler
: 在包torch.utils.data.distributed
中,采样器可以约束数据加载进数据集的子集。
其继承关系如图所示:
Sampler类是所有的采样器的基类,每一个继承自Sampler的子类都必须实现它的__iter__()
方法和__len__
()方法。
__iter__()
实现如何迭代样本__len__()
返回一共有多少个样本对于默认使用的采样器,其实现源码如下:
if batch_sampler is None: # 没有手动传入batch_sampler参数时
if sampler is None: # 没有手动传入sampler参数时
if shuffle:
sampler = RandomSampler(dataset) # 随机采样
else:
sampler = SequentialSampler(dataset) # 顺序采样
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.sampler = sampler
self.batch_sampler = batch_sampler
self.__initialized = True
网上找到的一个视频处理采样器:
class RandomSequenceSampler(Sampler):
# 作用与BatchSampler有点类似,每seq_len个视频shuffle
def __init__(self, n_sample, seq_len):
self.n_sample = n_sample # 视频的数量
self.seq_len = seq_len # 视频序列长度
def _pad_ind(self, ind):
zeros = np.zeros(self.seq_len - self.n_sample % self.seq_len)
ind = np.concatenate((ind, zeros))
return ind
def __iter__(self):
idx = np.arange(self.n_sample)
if self.n_sample % self.seq_len != 0:
idx = self._pad_ind(idx)
idx = np.reshape(idx, (-1, self.seq_len))
np.random.shuffle(idx)
idx = np.reshape(idx, (-1))
return iter(idx.astype(int))
def __len__(self):
return self.n_sample + (self.seq_len - self.n_sample % self.seq_len)
在继承Dataset类的自定义类中,__getitem__()
方法一般返回一组类似于[input,label]
的一个样本,而在创建DataLoader类的对象时,collate_fn函数会将batch_size个样本整理成一个batch样本,便于批量训练。
如果在DataLoader中不设置
collate_fn
,则会使用默认的函数default_collate(batch)
,在该方法中的有self.dataset[i] for i in indices
其中,indices是该batch_size中从Dataset子类中获取的索引集合,而self.dataset[i]就是Dataset子类中
__getitem__()
返回的结果。默认的函数default_collate(batch) 只能对大小相同的batch_size个input进行整理,
将
[(input0, label0), (input1, label1),(input2, label2), ]
整理成([input0,input1,input2,], [label0,label1,label2,])
, 这里要求多个input的大小要相同,如果不相同时候需要使用自定义函数callate_fn来处理。对于目标检测,其输入一般是
(input,box,label)
形式,这种也需要自定义,因为默认函数只能处理(input,label)
格式。
简单的collate_fn函数参考:
函数定义形式:
def collate_fn(self, batch):
for unit in batch:
unit_x.append(unit[0])
unit_y.append(unit[1])
...
return {x: torch.tensor(unit_x), y: torch.tensor(unit_y)}
# 使用,直接将函数名传进去就好
loader = Dataloader(collate_fn=collate_fn)
说明,这里的batch是该batch_size中的数据集合
函数输入形式:
[(input0, label0), (input1, label1),(input2, label2),...]
函数输出形式:
([input0,input1,input2,...], [label0,label1,label2,...])
创建可被调用的类的形式:
class collater():
def __init__(self, *params):
self. params = params
def __call__(self, data):
'''在这里重写collate_fn函数'''
# 对于类的形式,使用的时候是,创建对象作为输入即可
collate_fn = collater(*params)
loader = Dataloader(collate_fn=collate_fn)
对于目标检测的自定义collate_fn函数参考如下:
def collate_fn(self, batch):
paths, imgs, targets = list(zip(*batch))
# Remove empty placeholder targets
# 有可能__getitem__返回的图像是None, 所以需要过滤掉
targets = [boxes for boxes in targets if boxes is not None]
# Add sample index to targets
# boxes是每张图像上的目标框,但是每个图片上目标框数量不一样呢,所以需要给这些框添加上索引,对应到是哪个图像上的框。
for i, boxes in enumerate(targets):
boxes[:, 0] = i
targets = torch.cat(targets, 0)
# Selects new image size every tenth batch
if self.multiscale and self.batch_count % 10 == 0:
self.img_size = random.choice(range(self.min_size, self.max_size + 1, 32))
# Resize images to input shape
# 每个图像大小不同呢,所以resize到统一大小
imgs = torch.stack([resize(img, self.img_size) for img in imgs])
self.batch_count += 1
return paths, imgs, targets