Pytorch将数据集和数据集的加载定义为两个单独对象,使数据集代码和模型训练代码相分离,以获得更好的可读性和模块化;Pytorch提供了两个DataSet和DataLoader两个类。
数据集对象抽象类,加载自定义的数据需要继承 DataSet(torch.utils.data)类。
Pytorch支持两种类型的DataSet:Map 类型 DataSet 和 Iterable 类型 DataSet。
需要实现 __getitem__() 和 __len__() 函数,表示从索引/键到数据样本的映射。数据集在使用时,通过索引直接获取相关样本数据。如 dataset[idx] 表示使用idx从磁盘上的文件夹中读取第idx 个图像及其相应的标签。
__len__(): 使用该函数返回数据集的大小;
__getitem__():接收一个index, 查找数据和标签,index 是一个 list 的index,list中的每个元素包含数据和标签,其只有在用到的时候,才将数据读入;
index的取值范围是根据__len__()的返回值确定的;
class MyDataSet(Dataset): def __init__(self, X, Y): self.X = X # 样本 self.Y = Y # label def __getitem__(self, idx): item = {key: torch.tensor(value[idx]) for key, value in self.X.items()} item['label'] = torch.tensor(int(self.Y[idx])) return item def __len__(self): return len(self.Y)
需要实现函数 __iter__(),对数据样本进行迭代访问,特别适用于随机读取代价高昂以及批量大小取决于获取的数据等场景。例如,从数据库、远程服务器甚至实时生成的日志中读取的数据流场景中,可使用 iter(dataset) 访问数据。
class MyIterableDataSet(IterableDataset): def __init__(self, file_path): self.file_path = file_path def __iter__(self): with open(self.file_path, 'r') as file_obj: for line in file_obj: line_data = line.strip('\n').split(',') yield line_data
迭代器,传入 Dataset 对象,按照 batch_size 取数据, 取出大小等同于batch_size的index列表,将列表中的 index 输入到 dataset 的getitem()函数中,取出该index对应的数据,堆叠每个index对应的数据,构成一个batch的数据。
参数含义:
dataset:定义的dataset类返回的结果;
batch_size:每个bacth要加载的样本数,默认为1;
shuffle:在每个epoch中对整个数据集data进行shuffle重排,默认为False;
sampler:从数据集中加载数据采用的策略,如果指定,shuffle必须为False;
默认采用SequentialSampler,按顺序逐一采样,常用的有随机采样器:RandomSampler,当dataloader的 shuffle = True 时,系统会自动调用这个采样器,实现打乱数据;另外一个很有用的采样方法: WeightedRandomSampler,根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用来重采样;
from torch.utils.data.sampler import WeightedRandomSampler
# 如果label为1,那么对应的该类别被取出来的概率是另外一个类别的2倍
weights = [2 if label == 1 else 1 for data, label in dataset]
sampler = WeightedRandomSampler(weights,num_samples=10, replacement=True)
dataloader = DataLoader(dataset, batch_size=16, sampler=sampler)
batch_sample:表示一次返回一个 batch 的 index;
num_workers:表示开启多少个线程数去加载你的数据,默认为0,代表只使用主进程;
如果不使用默认值,会一次性创建num_workers个子线程,用batch_sampler将指定batch分配给指定worker,worker将对应 batch 加载进 RAM,dataloader 直接从RAM中找本轮迭代用的batch。如果num_worker 设置很大,
优点:寻batch速度快,因为下一轮迭代的batch可能在之前迭代时已经加载好了;
缺点:内存开销大,加重了CPU负担(worker加载数据到RAM的进程是进行CPU复制);
如果num_worker 为0,每一轮迭代时,dataloader不再有自主加载数据到RAM这一步骤,只有需要的时候再加载相应的batch,速度就慢了;num_workers 的经验设置值是电脑/服务器的CPU核心数,如果CPU很强、RAM很充足,可以设置得更大些,对于单机来说,单跑一个任务的话,直接设置为CPU的核心数最好;
collate_fn:表示合并样本列表以形成小批量的Tensor对象;在最后一步堆叠的时候可能会出现问题,如果一条数据中所含有的每个数据元的长度不同,将无法进行堆叠。此时需要先进行长度上补齐再堆叠。collate-fn() 就是手动将抽取样本堆叠的函数。
collate_fn 是作为参数传入 DataLoader的, 默认参数是 getitem 函数返回的数据项的 batch 形成的列表。collate_fn 可自定义取出一个batch数据的格式,该函数的输出就是对dataloader 遍历, 取出一个batch的数据。
传入额外参数的方法:使用lambda或创建可被调用类。
lambda:
info = args.info # info是已经定义过的
loader = Dataloader(collate_fn=lambda x: collate_fn(x, info))
创建可被调用类
class collater():
def __init__(self, *params):
self. params = params
def __call__(self, data):
'''在这里重写collate_fn函数'''
collate_fn = collater(*params)
loader = Dataloader(collate_fn=collate_fn)
pin_memory:表示要将load进来的数据是否要拷贝到pin_memory区中,生成的Tensor数据是属于内存中的锁页内存区,将Tensor数据转义到GPU中速度就会快一些,默认为False;通常情况下,数据在内存中要么以锁页的方式存在,要么保存在虚拟内存(磁盘)中,设置为True后,数据直接保存在锁页内存中,后续直接传入cuda;否则需要先从虚拟内存中传入锁页内存中,再传入cuda,这样就比较耗时了,但是对于内存的大小要求比较高。
drop_last:整个数据长度不能够整除batch_size,是否要丢弃最后的batch,默认False;
参考:
Pytorh学习——DataSet和DataLoader_MatrixSpace001的博客-CSDN博客
pytorch -- 构建自己的Dateset,DataLoader如何使用_torch.utils.data_无脑敲代码,bug漫天飞的博客-CSDN博客
pytorch构造可迭代的Dataset——IterableDataset(pytorch Data学习二)_呆萌的代Ma的博客-CSDN博客
https://pytorch.org/docs/stable/data.html#