前言
本文介绍了classdataset的几个要点,由哪些部分组成,每个部分需要完成哪些事情,如何进行数据增强,如何实现自己设计的数据增强。然后,介绍了分布式训练的数据加载方式,数据读取的整个流程,当面对超大数据集时,内存不足的改进思路。
本文延续了以往的写作态度和风格,即便是自己知道的内容,也仍然在写之前看了很多的文章来保证内容的正确性和全面性,因此写得极累,耗费时间较长。若有读者看完后觉得有所帮助,文末可以赞赏一点。
文末扫描二维码关注公众号CV技术指南 ,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读,招聘信息发布。
(零) 概述
浮躁是人性的一个典型的弱点,很多人总擅长看别人分享的现成代码解读的文章,看起来学会了好多东西,实际上仍然不具备自己从零搭建一个pipeline的能力。
在公众号(CV技术指南)的交流群里(群内交流氛围不错,有需要的请关注公众号加群),常有不少人问到一些问题,根据这些问题明显能看出是对pipeline不了解,却已经在搞项目或论文了,很难想象如果基本的pipeline都不懂,如何分析代码问题所在?如何分析结果不正常的可能原因?遇到问题如何改?
Pytorch在这几年逐渐成为了学术上的主流框架,其具有简单易懂的特点。网上有很多pytorch的教程,如果是一个已经懂的人去看这些教程,确实pipeline的要素都写到了,感觉这教程挺不错的。但实际上更多地像是写给自己看的一个笔记,记录了pipeline要写哪些东西,却没有介绍要怎么写,为什么这么写,刚入门的小白看的时候容易云里雾里。
鉴于此,本教程尝试对于pytorch搭建一个完整pipeline写一个比较明确且易懂的说明。
本教程将介绍以下内容:
-
准备数据,自定义classdataset,分布式训练的数据加载方式,加载超大数据集的改进思路。
-
搭建模型与模型初始化。
-
编写训练过程,包括加载预训练模型、设置优化器、设置损失函数等。
-
可视化并保存训练过程。
-
编写推理函数。
(一)数据读取
classdataset的定义
先来看一个完整的classdataset
classdataset的几个要点:
-
classdataset类继承torch.utils.data.dataset。
-
classdataset的作用是将任意格式的数据,通过读取、预处理或数据增强后以tensor的形式输出。其中任意格式的数据可能是以文件夹名作为类别的形式、或以txt文件存储图片地址的形式、或视频、或十几帧图像作为一份样本的形式。而输出则指的是经过处理后的一个batch的tensor格式数据和对应标签。
-
classdataset主要有三个函数要完成:__init__函数、__getitem__ 函数和__len__函数。
__init__函数
init函数主要是完成两个静态变量的赋值。一个是用于存储所有数据路径的变量,变量的每个元素即为一份训练样本,(注:如果一份样本是十几帧图像,则变量每个元素存储的是这十几帧图像的路径),可以命名为self.filenames。一个是用于存储与数据路径变量一一对应的标签变量,可以命名为self.labels。
假如数据集的格式如下:
可通过per_classes = os.listdir(data_path) 获得所有类别的文件夹,在此处per_classes的每个元素即为对应的数据标签,通过for遍历per_classes即可获得每个类的标签,将其转换成int的tensor形式即可。在for下获得每个类下每张图片的路径,通过self.join获得每份样本的路径,通过append添加到self.filenames中。
__getitem__ 函数
getitem 函数主要是根据索引返回对应的数据。这个索引是在训练前通过dataloader切片获得的,这里先不管。它的参数默认是index,即每次传回在init函数中获得的所有样本中索引对应的数据和标签。因此,可通过下面两行代码找到对应的数据和标签。
获得数据后,进行数据预处理。数据预处理主要通过 torchvision.transforms 来完成,这里面已经包含了常用的预处理、数据增强方式。其完整使用方式在官网有详细介绍:https://pytorch.org/vision/stable/transforms.html
上面这里介绍了最常用的几种,主要就是resize,随机裁剪,翻转,归一化等。
最后通过transforms.Compose(transform_train_list)来执行。
除了这些已经有的数据增强方式外,在《
下面以随机擦除作为例子。
如上所示,自己写一个类RandomErasing,继承object,在call函数里完成你的操作。在transform_train_list里添加上RandomErasing的定义即可。
transform_train_list = [ transforms.Resize((self.opt.h, self.opt.w), interpolation=3), transforms.Pad(self.opt.pad, padding_mode='edge'), transforms.RandomCrop((self.opt.h, self.opt.w)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) RandomErasing(probability=self.opt.erasing_p, mean=[0.0, 0.0, 0.0]) #添加到这里 ]
__len__函数
len函数主要就是返回数据长度,即样本的总数量。前面介绍了self.filenames的每个元素即为每份样本的路径,因此,self.filename的长度就是样本的数量。通过return len(self.filenames)即可返回数据长度。
验证classdataset
分布式训练的数据加载方式
前面介绍的是单卡的数据加载,实际上分布式也是这样,但为了高速高效读取,每张卡上也会保存所有数据的信息,即self.filenames和self.labels的信息。只是在DistributedSampler 中会给每张卡分配互不交叉的索引,然后由torch.utils.data.DataLoader来加载。
dataset = My_Dataset(data_folder=data_folder) sampler = DistributedSampler(dataset) if is_distributed else None loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)
数据读取的完整流程
结合上面这段代码,在这里,我们介绍以下读取数据的整个流程。
-
首先定义一个classdataset,在初始化函数里获得所有数据的信息。
-
classdataset中实现getitem函数,通过索引来获取对应的数据,然后对数据进行预处理和数据增强。
-
在模型训练前,初始化classdataset,通过Dataloader来加载数据,其加载方式是通过Dataloader中分配的索引,调用getitem函数来获取。
关于索引的分配,在单卡上,可通过设置shuffle=True来随机生成索引顺序;在多机多卡的分布式训练上,shuffle操作通过DistributedSampler来完成,因此shuffle与sampler只能有一个,另一个必须为None。
超大数据集的加载思路
问题所在
再回顾一下上面这个流程,前面提到所有数据信息在classdataset初始化部分都会保存在变量中,因此当面对超大数据集时,会出现内存不足的情况。
思路
将切片获取索引的步骤放到classdataset初始化的位置,此时每张卡都是保存不同的数据子集。通过这种方式,可以将内存用量减少到原来的world_size倍(world_size指卡的数量)。
参考代码
class RankDataset(Dataset): ''' 实际流程 获取rank和world_size 信息 -> 获取dataset长度 -> 根据dataset长度产生随机indices -> 给不同的rank 分配indices -> 根据这些indices产生metas ''' def __init__(self, meta_file, world_size, rank, seed): super(RankDataset, self).__init__() random.seed(seed) np.random.seed(seed) self.world_size = world_size self.rank = rank self.metas = self.parse(meta_file) def parse(self, meta_file): dataset_size = self.get_dataset_size(meta_file) # 获取metafile的行数 local_rank_index = self.get_local_index(dataset_size, self.rank, self.world_size) # 根据world size和rank,获取当前epoch,当前rank需要训练的index。 self.metas = self.read_file(meta_file, local_rank_index) def __getitem__(self, idx): return self.metas[idx] def __len__(self): return len(self.metas) ##train for epoch_num in range(epoch_num): dataset = RankDataset("/path/to/meta", world_size, rank, seed=epoch_num) sampler = RandomSampler(datset) dataloader = DataLoader( dataset=dataset, batch_size=32, shuffle=False, num_workers=4, sampler=sampler)
这一节参考链接:https://zhuanlan.zhihu.com/p/357809861
总结
本篇文章介绍了数据读取的完整流程,如何自定义classdataset,如何进行数据增强,自己设计的数据增强如何写,分布式训练是如何加载数据的,超大数据集的数据加载改进思路。
相信读完本文的读者对数据读取有了比较清晰的认识,下一篇将介绍搭建模型与模型初始化。
关注公众号可加计算机视觉交流群
欢迎关注公众号 CV技术指南 ,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读。
在公众号中回复关键字 “入门指南“可获取计算机视觉入门所有必备资料。
其它文章