读《python计算机视觉与深度学习实战》(郭卡,戴亮编著)笔记·part1

1.Dataset

pytorch中提供了两种Dataset,一种是Dataset,另一种是IterableDataset。

在构建Dataset子类的时候,一般来说只需要定义__init__、__get_item__和__len__这3个方法,它们的作用分别如下:

__init__:初始化类

__get_item__:提取Dataset中的元素,通常是元组形式,如(input,target)

__len__:在对Dataset取len时,返回Dataset中的元素个数

IterableDataset是一个迭代器,需要重写__iter__方法,通过__iter__方法获得下一条数据。(这个目前没有遇到,待深入研究)

2.DataLoader

DataLoader提供了将数据整合成一个个批次的方法,用于进行模型批量运算。DataLoader中有如下几个需要注意的参数:

batch_size:一个批次数据中的样本数量

shuffle:打扰数据,避免模型陷入局部最优的情况,在定义了sampler之后,这个参数就无法使用了

sampler:采样器,如果有特殊的数据整合需求,可以自定义一个sampler,在sampler中返回每个批次的数据下标列表

pin_memory:将数据传入CUDA的Pinned Memory,方便更快的传入GPU中

collate_fn:进一步处理打包sampler筛选出来的一组组数据

num_workers:采用多进程方式加载,如果CPU能力较强,可以选择这种方法

drop_last:在样本总数不能被批次大小整除的情况下,最后一个批次的样本数量可能会与前面的批次不一致,若模型要求每个批次样本数量一致,可以将drop_last设置为True

你可能感兴趣的:(python,深度学习,计算机视觉)