torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False, pin_memory_device='')
In PyTorch, three important arguments to the DataLoader are dataset, batch_size, and sampler.
dataset: the dataset to load from.
batch_size: the number of samples fed in per batch.
sampler: the strategy for drawing data from dataset. PyTorch provides several common samplers: SequentialSampler, RandomSampler, SubsetRandomSampler, and WeightedRandomSampler.
If no sampler is specified, the DataLoader feeds the data in sequential order, for example:
import torch
import numpy as np
a = torch.from_numpy(np.arange(10,20))
dataloader = torch.utils.data.DataLoader(a, batch_size=3, shuffle=False)
for i, x in enumerate(dataloader):
    print(i, x)
The output is:
0 tensor([10, 11, 12])
1 tensor([13, 14, 15])
2 tensor([16, 17, 18])
3 tensor([19])
For more details, see the official documentation: https://pytorch.org/docs/stable/data.html#torch.utils.data.Sampler
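The signature above also shows a drop_last argument. As a minimal sketch (using the same dataset a as in the example above), setting drop_last=True discards the trailing incomplete batch instead of yielding it:

```python
import numpy as np
import torch

a = torch.from_numpy(np.arange(10, 20))

# With drop_last=True the final incomplete batch (the lone element 19) is discarded
dataloader = torch.utils.data.DataLoader(a, batch_size=3, drop_last=True)
batches = [x for x in dataloader]
# batches now holds only the three full batches of size 3
```

This is commonly used during training when downstream code assumes every batch has the same size.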
SequentialSampler is the same as not specifying a sampler: it yields the batches in order:
sampler = torch.utils.data.SequentialSampler(a)
dataloader = torch.utils.data.DataLoader(a, batch_size=3, sampler=sampler, shuffle=False)
for i, x in enumerate(dataloader):
    print(i, x)
The output is:
0 tensor([10, 11, 12])
1 tensor([13, 14, 15])
2 tensor([16, 17, 18])
3 tensor([19])
We can print the sampler to see what it yields:
for i in sampler:
    print(i)
As you can see, it yields the indices of the dataset a:
0
1
2
3
...
9
RandomSampler shuffles the data and yields the batches in random order:
sampler = torch.utils.data.RandomSampler(a)
dataloader = torch.utils.data.DataLoader(a, batch_size=3, sampler=sampler, shuffle=False)
for i, x in enumerate(dataloader):
    print(i, x)
The output is:
0 tensor([14, 17, 11])
1 tensor([18, 15, 12])
2 tensor([19, 13, 16])
3 tensor([10])
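Since the order above changes on every run, a sketch like the following uses RandomSampler's generator argument to make the shuffle reproducible; note that each element still appears exactly once per epoch:

```python
import numpy as np
import torch

a = torch.from_numpy(np.arange(10, 20))

# A seeded torch.Generator fixes the random order across runs
g = torch.Generator()
g.manual_seed(0)
sampler = torch.utils.data.RandomSampler(a, generator=g)
dataloader = torch.utils.data.DataLoader(a, batch_size=3, sampler=sampler)

# Concatenate all batches back into one tensor to inspect the epoch's order
order = torch.cat(list(dataloader))
```

Re-running this snippet produces the same permutation each time, which is useful for debugging.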
SubsetRandomSampler draws samples from a subset of the data; you generate the random indices yourself:
import numpy.random as random
n_a = len(a)
indices = random.permutation(list(range(n_a)))
sampler = torch.utils.data.SubsetRandomSampler(indices[:8])
dataloader = torch.utils.data.DataLoader(a, batch_size=3, sampler=sampler, shuffle=False)
for i, x in enumerate(dataloader):
    print(i, x)
The output is:
0 tensor([16, 10, 14])
1 tensor([11, 17, 15])
2 tensor([13, 19])
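A common use of SubsetRandomSampler is a train/validation split: shuffle the indices once, split them, and build one sampler per subset. A minimal sketch on the same dataset a (the 80/20 split ratio here is just an illustrative choice):

```python
import numpy as np
import torch

a = torch.from_numpy(np.arange(10, 20))

# Shuffle the indices once, then split 80/20 into disjoint train/validation sets
rng = np.random.default_rng(0)
indices = rng.permutation(len(a))
train_sampler = torch.utils.data.SubsetRandomSampler(indices[:8])
val_sampler = torch.utils.data.SubsetRandomSampler(indices[8:])

train_loader = torch.utils.data.DataLoader(a, batch_size=3, sampler=train_sampler)
val_loader = torch.utils.data.DataLoader(a, batch_size=3, sampler=val_sampler)

train_items = torch.cat(list(train_loader)).tolist()
val_items = torch.cat(list(val_loader)).tolist()
```

Because both samplers draw from disjoint index sets, no sample leaks from training into validation.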
WeightedRandomSampler assigns each sample its own weight (the weights need not sum to one):
>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[4, 4, 1, 4, 5]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]
Besides specifying batch_size in the DataLoader, you can also use a BatchSampler to form the batches. For example:
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
from torch.utils.data import SequentialSampler, BatchSampler
n_a = len(a)
sampler = BatchSampler(SequentialSampler(range(n_a)), batch_size=3, drop_last=False)
dataloader = torch.utils.data.DataLoader(a, sampler=sampler, shuffle=False)
for i, x in enumerate(dataloader):
    print(i, x)
The result is two-dimensional: because the BatchSampler is passed as sampler=, each list of indices it yields is treated as a single sample, and the DataLoader's default batch_size=1 then wraps it in an extra batch dimension:
0 tensor([[10, 11, 12]])
1 tensor([[13, 14, 15]])
2 tensor([[16, 17, 18]])
3 tensor([[19]])
Printing the sampler shows that each element it yields is a list of batch_size indices:
for i in sampler:
    print(i)
Output:
[0, 1, 2]
[3, 4, 5]
[6, 7, 8]
[9]
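To avoid that extra dimension, pass the BatchSampler through the DataLoader's batch_sampler= argument instead of sampler=; each index list is then treated as one batch, as this sketch shows:

```python
import numpy as np
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

a = torch.from_numpy(np.arange(10, 20))

# batch_sampler= tells the DataLoader that each yielded index list is already
# a full batch, so it no longer wraps it with the default batch_size=1
batch_sampler = BatchSampler(SequentialSampler(range(len(a))),
                             batch_size=3, drop_last=False)
dataloader = DataLoader(a, batch_sampler=batch_sampler)

batches = [x for x in dataloader]  # ordinary one-dimensional batches
```

Note that batch_sampler is mutually exclusive with batch_size, shuffle, sampler, and drop_last, which must be left at their defaults.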