torch.utils.data.DataLoader()详解

由于中文文档里面没有写这个类但是我们经常用它,所以这里进行一下分析

官网链接

类定义

torch.utils.data.DataLoader()详解_第1张图片

参数

torch.utils.data.DataLoader()详解_第2张图片

额外信息

torch.utils.data.DataLoader()详解_第3张图片

使用方法以及要点

不用sampler

# 训练数据集的加载器,自动将数据分割成batch,顺序随机打乱
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                            drop_last = True ,      
                                           shuffle=True)

使用sampler

 首先,我们定义下标数组indices,它相当于对所有test_dataset中数据的编码
# 然后定义下标indices_val来表示校验集数据的那些下标,indices_test表示测试集的下标
indices = range(len(test_dataset))
indices_val = indices[:5000]
indices_test = indices[5000:]

# 根据这些下标,构造两个数据集的SubsetRandomSampler采样器,它会对下标进行采样
sampler_val = torch.utils.data.sampler.SubsetRandomSampler(indices_val)
sampler_test = torch.utils.data.sampler.SubsetRandomSampler(indices_test)

# 根据两个采样器来定义加载器,注意将sampler_val和sampler_test分别赋值给了validation_loader和test_loader
validation_loader = torch.utils.data.DataLoader(dataset =test_dataset,
                                                batch_size = batch_size,
                                                sampler = sampler_val
                                               )
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          sampler = sampler_test
                                         )

特别注意

可能出现batch_size小于预期的情况,请指定drop_last = True解决

你可能感兴趣的:(GAN)