DataLoader

DataLoader

参数:
DataLoader_第1张图片
常用参数:
dataset(Dataset):之前dataset类型的数据集
batch_size(int):每个batch的大小
shuffle(bool):是否打乱
num_workers(int): 多少子进程同时运行,0代表主进程,windows下如果出现brokenpipe情况,改为0
drop_last(bool): 是否丢掉最后一个不够batch_size的batch,True为丢掉
sampler: 采样的方法,如果使用了就不能指定shuffle参数,默认情况下shuffle=True对应随机采样、shuffle=False对应顺序采样
batch_sampler:与sampler相同,但返回的是一个batch个数的索引

运行示例:

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

test_data = torchvision.datasets.CIFAR10('./data', train=False, transform= torchvision.transforms.ToTensor())

test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

img, target = test_data[0]
print(img.shape)
print(target)

writer = SummaryWriter('dataloader')

# epoch循环用来查看shuffle的效果
for epoch in range(2):
    # 正常读取就是对DataLoader做for循环
    step = 0
    for data in test_loader:
        imgs, targets = data
        writer.add_images('epoch: {}'.format(epoch), imgs, step)
        step += 1

writer.close()

DataLoader的实现

DataLoader_第2张图片

整体过程是,sampler规定抽取单个索引的方法,batchsampler调用sampler抽取batchsize个索引,
在dataloader的__next__中调用batchsampler取出batchsize个的索引所对应的图片和标签并返回

相当于dataloader根据dataset的getitem函数,按照batchsize大小进行分组,然后按组返回数据和标签。
具体细节可参考

你可能感兴趣的:(pytorch,深度学习,人工智能)