Samplers in PyTorch

How to use the common samplers in PyTorch.

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False, pin_memory_device='')

When constructing a DataLoader in PyTorch, three important parameters are dataset, batch_size, and sampler:

  • dataset: the dataset to load from
  • batch_size: the number of samples fed per batch
  • sampler: the strategy for drawing samples from dataset

PyTorch provides several common samplers: SequentialSampler, RandomSampler, SubsetRandomSampler, and WeightedRandomSampler.
If no sampler is specified, DataLoader feeds the data in sequential order, for example:

import torch
import numpy as np
a = torch.from_numpy(np.arange(10, 20))

dataloader = torch.utils.data.DataLoader(a, batch_size=3, shuffle=False)
for i, x in enumerate(dataloader):
    print(i, x)

The output is:

0 tensor([10, 11, 12])
1 tensor([13, 14, 15])
2 tensor([16, 17, 18])
3 tensor([19])

For more details, see the official documentation: https://pytorch.org/docs/stable/data.html#torch.utils.data.Sampler

1. SequentialSampler

SequentialSampler is what DataLoader uses by default when shuffle=False; it produces batches in sequential order, just like the example above:

sampler = torch.utils.data.SequentialSampler(a)
dataloader = torch.utils.data.DataLoader(a, batch_size=3, sampler=sampler, shuffle=False)
for i, x in enumerate(dataloader):
    print(i, x)

The output is the same:

0 tensor([10, 11, 12])
1 tensor([13, 14, 15])
2 tensor([16, 17, 18])
3 tensor([19])

We can iterate over the sampler directly to see what it yields:

for i in sampler:
    print(i)

It yields the indices of dataset a, from 0 to 9:

0
1
2
3
...
9
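
As the output shows, a sampler is just an iterable over dataset indices, which means you can also write your own. Below is a minimal sketch (the class name ReverseSampler is made up for illustration) of a custom sampler that yields indices in reverse order:

from torch.utils.data import Sampler

class ReverseSampler(Sampler):
    # Yields the indices of data_source from last to first.
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self):
        return len(self.data_source)

dataloader = torch.utils.data.DataLoader(a, batch_size=3, sampler=ReverseSampler(a))
for i, x in enumerate(dataloader):
    print(i, x)
# 0 tensor([19, 18, 17])
# 1 tensor([16, 15, 14])
# 2 tensor([13, 12, 11])
# 3 tensor([10])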

2. RandomSampler

RandomSampler shuffles the order of the data and outputs the batches in random order:

sampler = torch.utils.data.RandomSampler(a)
dataloader = torch.utils.data.DataLoader(a, batch_size=3, sampler=sampler, shuffle=False)
for i, x in enumerate(dataloader):
    print(i, x)

The output is:

0 tensor([14, 17, 11])
1 tensor([18, 15, 12])
2 tensor([19, 13, 16])
3 tensor([10])
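
RandomSampler also accepts a few options worth knowing: a generator for a reproducible shuffle, and replacement=True together with num_samples for drawing with replacement. A minimal sketch:

# Seeded generator: the same permutation of 0..9 on every program run.
g = torch.Generator()
g.manual_seed(0)
sampler = torch.utils.data.RandomSampler(a, generator=g)
print(list(sampler))

# 5 draws with replacement: the same index may appear more than once.
boot = torch.utils.data.RandomSampler(a, replacement=True, num_samples=5)
print(list(boot))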

3. SubsetRandomSampler

SubsetRandomSampler draws randomly from a subset of the data; you generate the random indices yourself.

import numpy.random as random

n_a = len(a)
indices = random.permutation(list(range(n_a)))
sampler = torch.utils.data.SubsetRandomSampler(indices[:8])
dataloader = torch.utils.data.DataLoader(a, batch_size=3, sampler=sampler, shuffle=False)
for i, x in enumerate(dataloader):
    print(i, x)

The output is:

0 tensor([16, 10, 14])
1 tensor([11, 17, 15])
2 tensor([13, 19])
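
A common use of SubsetRandomSampler is splitting one dataset into a training and a validation loader. A minimal sketch, assuming an 80/20 split:

from torch.utils.data import SubsetRandomSampler

indices = random.permutation(len(a))      # shuffle the indices once
split = int(0.8 * len(a))                 # 80% train, 20% validation
train_sampler = SubsetRandomSampler(indices[:split])
val_sampler = SubsetRandomSampler(indices[split:])

train_loader = torch.utils.data.DataLoader(a, batch_size=3, sampler=train_sampler)
val_loader = torch.utils.data.DataLoader(a, batch_size=3, sampler=val_sampler)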

4. WeightedRandomSampler

WeightedRandomSampler assigns a weight to each sample and draws index i with probability proportional to weights[i] (the weights need not sum to 1):

>>> from torch.utils.data import WeightedRandomSampler
>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[4, 4, 1, 4, 5]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]
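
A typical application is oversampling a minority class so that every class is drawn equally often in expectation. A minimal sketch with made-up toy labels:

from torch.utils.data import WeightedRandomSampler

# Imbalanced toy labels: 8 samples of class 0, 2 samples of class 1.
labels = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1])

# Weight each sample by the inverse frequency of its class.
class_counts = torch.bincount(labels)           # tensor([8, 2])
weights = 1.0 / class_counts[labels].float()    # 0.125 for class 0, 0.5 for class 1

sampler = WeightedRandomSampler(weights, num_samples=len(labels), replacement=True)
print(list(sampler))  # indices 8 and 9 now appear far more often than 2 times in 10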

5. BatchSampler

Besides setting batch_size in DataLoader, batches can also be formed with BatchSampler, which wraps another sampler and yields lists of indices. For example:

>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]

Passing it to DataLoader:

from torch.utils.data import SequentialSampler, BatchSampler

n_a = len(a)
sampler = BatchSampler(SequentialSampler(range(n_a)), batch_size=3, drop_last=False)
dataloader = torch.utils.data.DataLoader(a, sampler=sampler, shuffle=False)
for i, x in enumerate(dataloader):
    print(i, x)

The result is (why is it 2-D? see the explanation below):

0 tensor([[10, 11, 12]])
1 tensor([[13, 14, 15]])
2 tensor([[16, 17, 18]])
3 tensor([[19]])
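
The extra dimension appears because the BatchSampler was passed as sampler: DataLoader then treats each list of 3 indices as one sample and adds its own batch dimension of size 1. Passing it through the batch_sampler argument instead gives the expected 1-D batches:

sampler = BatchSampler(SequentialSampler(range(n_a)), batch_size=3, drop_last=False)
dataloader = torch.utils.data.DataLoader(a, batch_sampler=sampler)
for i, x in enumerate(dataloader):
    print(i, x)
# 0 tensor([10, 11, 12])
# 1 tensor([13, 14, 15])
# 2 tensor([16, 17, 18])
# 3 tensor([19])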

Iterating over the sampler shows that each element is a list of batch_size indices:

for i in sampler:
    print(i)

The output is:

[0, 1, 2]
[3, 4, 5]
[6, 7, 8]
[9]
