Pytorch:Dataset和DatasetLoader

学习内容

  • 深度学习数据加载流程:
  • 数据加载核心类:torch.utils.data.DataLoader
  • 数据集:torch.utils.data.Dataset
  • 采样器sampler:torch.utils.data.sampler.Sampler

深度学习数据加载流程:

# 创建自定义数据集, 建立索引到数据样本的映射
myData = MyDataset(**args)
# 通过DataLoader()加载myData,以指定方式从数据集中迭代生成 batch 样本集合
dataLoader = DataLoader(dataset, batch_size, shuffle, num_works)
# DataLoader为模型提供训练数据,根据sampler指定的策略生成训练样本
for i in range(epoch):
	for idx, (sequence, ans) in enumerate(dataLoader):
		pass

数据加载核心类:torch.utils.data.DataLoader

# DataLoader的构造函数参数配置
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False) 
  • dataset (Dataset) – 从中加载数据的数据集。
  • batch_size (int, optional) – 每批要加载多少个样本(默认值:1)。
  • shuffle (bool, optional) – 设置为 True时, 让数据在每个 epoch 重新洗牌(默认值:False)。
  • sampler(Sampler 或 Iterable,可选)——定义从数据集中抽取样本的策略。可以是任何实现了 len 的 Iterable。如果指定,则不得指定 shuffle。
  • batch_sampler(Sampler 或 Iterable,可选)- 类似于 sampler,但一次返回一批索引。与 batch_size、shuffle、sampler 和 drop_last 互斥。
  • num_workers (int, optional) – 用于数据加载的子进程数。0 表示数据将在主进程中加载​​。(默认值:0)。

数据集:torch.utils.data.Dataset

表示数据集的抽象类,任何自定义的数据集都要继承这个类,并重写相关方法。
Pytorch支持两种不同类型的数据集:

  • 映射类型数据集

所有映射类型数据集子类都应该重写__getitem__(self, index) ,支持获取给定键的数据样本。重写__len__(self) ,返回数据集的大小;表示索引到数据样本的映射。

class MyDataset(Dataset):#需要继承Dataset
    def __init__(self):
        # 初始化文件路径,文件名称
        pass
    def __getitem__(self, index):
        # 读取数据 读取的是一个样本,而不是全部数据
        # 数据预处理
        # 返回data pair
        pass
    def __len__(self):
        # 返回数据集大小
        return len(self.data)
  • 迭代类型数据集

    Mark

采样器sampler:torch.utils.data.sampler.Sampler

PyTorch提供的Sampler

  • SequentialSampler
  • RandomSampler
  • SubsetRandomSampler
  • WeightedRandomSampler
    详见: https://www.csdn.net/

也可以自己定义采样器:自定义时要继承 torch.utils.data.sampler.Sampler 抽象类。
我们在训练时常用的是对批量数据训练,而BatchSampler的作用就是将前面的Sampler采样得到的索引值合并成一个batch并返回。


你可能感兴趣的:(Pytorch散记,pytorch,深度学习,python)