There are some excellent blog posts on CSDN about how to use WeightedRandomSampler. Building on the post Pytorch样本比例不均衡时采用WeightedRandomSampler进行采样 (sampling with WeightedRandomSampler when class proportions are imbalanced), this article takes a closer look at how the sampler behaves.
Let's first expand on the description given in the official documentation:
torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)
Parameters:
weights (sequence) – a sequence of weights, not necessarily summing up to one. (One weight per sample, typically the inverse of its class size: with 10 cat images and 20 dog images, each cat image gets weight 1/10 and each dog image gets 1/20, i.e. roughly 0.67 and 0.33 after normalization, so weights is a length-30 sequence, not a 2-element list.)
num_samples (int) – number of samples to draw. (Continuing the example above: 30, so that one pass draws as many samples as there are images.)
replacement (bool) – if True, samples are drawn with replacement. If not, they are drawn without replacement, which means that when a sample index is drawn for a row, it cannot be drawn again for that row. (The examples below focus on this parameter, i.e. whether samples may be drawn repeatedly so that the sampling stays balanced.)
generator (Generator) – Generator used in sampling. (Can be left at its default.)
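To make weights concrete before the main example, here is a minimal sketch of the cat/dog case above (the labels and counts are assumptions purely for illustration):

import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical data: 10 cat images (class 0) followed by 20 dog images (class 1)
targets = torch.tensor([0] * 10 + [1] * 20)
class_weights = 1. / torch.tensor([10., 20.])  # inverse class sizes
sample_weights = class_weights[targets]        # one weight per sample, length 30
sampler = WeightedRandomSampler(sample_weights, num_samples=30, replacement=True)
drawn = list(sampler)                          # 30 sampled indices
print(sum(i < 10 for i in drawn))              # about 15: cats are drawn ~half the time

Because each cat's weight (1/10) is twice each dog's (1/20), the two classes contribute about equally despite the 10-vs-20 imbalance.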
Now for the code:
import torch
from torch.utils.data import DataLoader
from torch.utils.data import WeightedRandomSampler
# Create dummy data with class imbalance 10 / 50 / 60
class_counts = torch.tensor([10, 50, 60])
numDataPoints = class_counts.sum()
data_dim = 5
bs = 10
data = torch.randn(numDataPoints, data_dim)
for i in range(data.shape[0]):
    data[i, 0] = i  # store the row index in element 0, so printing it tells us the sample's class
# Layout: rows 0-9 are class 0; rows 10-59 are class 1; rows 60-119 are class 2
target = torch.cat((torch.zeros(class_counts[0], dtype=torch.long),
                    torch.ones(class_counts[1], dtype=torch.long),
                    torch.ones(class_counts[2], dtype=torch.long) * 2))
print('target train 0/1/2: {}/{}/{}'.format(
    (target == 0).sum(), (target == 1).sum(), (target == 2).sum()))
# Compute sample weights (each sample gets its own weight: the inverse of its class count)
class_sample_count = torch.tensor(
    [(target == t).sum() for t in torch.unique(target, sorted=True)])
weight = 1. / class_sample_count.float()
samples_weight = torch.tensor([weight[t] for t in target])
# Create sampler, dataset, loader
sampler = WeightedRandomSampler(samples_weight, len(samples_weight), replacement=True)
train_dataset = torch.utils.data.TensorDataset(data, target)
train_loader = DataLoader(
    train_dataset, batch_size=bs, num_workers=0, sampler=sampler)
# train_loader = DataLoader(
#     train_dataset, batch_size=bs, num_workers=0, shuffle=True)
# Iterate the DataLoader and check the class balance of each batch
for i, (x, y) in enumerate(train_loader):
    print("batch index {}, 0/1/2: {}/{}/{}".format(
        i, (y == 0).sum(), (y == 1).sum(), (y == 2).sum()))
    x_n = [el[0].tolist() for el in x]
    print(sorted(x_n))
In this example, class 0 has 10 samples, class 1 has 50, and class 2 has 60. We also set element 0 of every sample to its row index in the tensor, so that when the DataLoader emits a batch we can tell which class each sample came from.
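For reference, the mapping from a printed row index back to its class follows directly from this layout (a small illustrative helper, not part of the original code):

def row_to_class(row: int) -> int:
    # rows 0-9 -> class 0, rows 10-59 -> class 1, rows 60-119 -> class 2
    if row < 10:
        return 0
    return 1 if row < 60 else 2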
With replacement=True, the code above produces:
target train 0/1/2: 10/50/60
batch index 0, 0/1/2: 5/2/3
[1.0, 2.0, 4.0, 6.0, 8.0, 45.0, 45.0, 95.0, 101.0, 108.0]
batch index 1, 0/1/2: 1/5/4
[4.0, 30.0, 30.0, 39.0, 39.0, 54.0, 65.0, 67.0, 96.0, 97.0]
batch index 2, 0/1/2: 3/3/4
[4.0, 7.0, 9.0, 22.0, 43.0, 59.0, 90.0, 104.0, 111.0, 119.0]
batch index 3, 0/1/2: 4/3/3
[1.0, 3.0, 4.0, 8.0, 22.0, 49.0, 52.0, 64.0, 81.0, 89.0]
batch index 4, 0/1/2: 2/5/3
[3.0, 4.0, 10.0, 25.0, 42.0, 44.0, 49.0, 74.0, 100.0, 104.0]
batch index 5, 0/1/2: 1/5/4
[7.0, 22.0, 31.0, 32.0, 34.0, 37.0, 83.0, 113.0, 113.0, 115.0]
batch index 6, 0/1/2: 5/2/3
[2.0, 2.0, 3.0, 8.0, 9.0, 13.0, 20.0, 61.0, 75.0, 97.0]
batch index 7, 0/1/2: 3/3/4
[3.0, 7.0, 9.0, 31.0, 31.0, 38.0, 70.0, 71.0, 75.0, 91.0]
batch index 8, 0/1/2: 4/1/5
[2.0, 3.0, 4.0, 7.0, 23.0, 67.0, 71.0, 74.0, 95.0, 117.0]
batch index 9, 0/1/2: 3/3/4
[2.0, 3.0, 4.0, 17.0, 36.0, 56.0, 103.0, 104.0, 110.0, 115.0]
batch index 10, 0/1/2: 4/3/3
[3.0, 6.0, 7.0, 8.0, 18.0, 39.0, 43.0, 66.0, 95.0, 105.0]
batch index 11, 0/1/2: 2/1/7
[0.0, 2.0, 22.0, 77.0, 81.0, 100.0, 103.0, 103.0, 106.0, 111.0]
The class sizes here are quite skewed, yet every batch comes out close to balanced. Note that a single batch can contain duplicate samples (e.g. 45.0 appears twice in batch 0), and across batches the class-0 samples are drawn over and over.
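One way to verify the duplication is to iterate the sampler directly and count the draws; this quick sketch reuses the sampler defined above:

from collections import Counter

counts = Counter(int(idx) for idx in sampler)  # one epoch = 120 draws
repeats = {i: c for i, c in counts.items() if c > 1}
print(len(repeats), 'indices were drawn more than once')
# Each class-0 row has probability (1/10) / 3 = 1/30 per draw
# (the total weight is 10/10 + 50/50 + 60/60 = 3), so it is
# expected to be drawn about 120/30 = 4 times per epoch.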
Next we set replacement=False and run again:
target train 0/1/2: 10/50/60
batch index 0, 0/1/2: 4/5/1
[5.0, 6.0, 7.0, 9.0, 16.0, 22.0, 28.0, 30.0, 35.0, 69.0]
batch index 1, 0/1/2: 2/4/4
[0.0, 1.0, 10.0, 20.0, 26.0, 38.0, 82.0, 90.0, 111.0, 119.0]
batch index 2, 0/1/2: 1/3/6
[4.0, 17.0, 19.0, 23.0, 60.0, 63.0, 64.0, 68.0, 79.0, 92.0]
batch index 3, 0/1/2: 2/2/6
[3.0, 8.0, 40.0, 48.0, 89.0, 96.0, 100.0, 102.0, 104.0, 115.0]
batch index 4, 0/1/2: 1/3/6
[2.0, 32.0, 37.0, 55.0, 84.0, 88.0, 94.0, 106.0, 109.0, 110.0]
batch index 5, 0/1/2: 0/5/5
[12.0, 25.0, 29.0, 33.0, 34.0, 72.0, 78.0, 98.0, 107.0, 113.0]
batch index 6, 0/1/2: 0/3/7
[42.0, 43.0, 49.0, 70.0, 83.0, 85.0, 93.0, 95.0, 97.0, 117.0]
batch index 7, 0/1/2: 0/5/5
[21.0, 27.0, 44.0, 57.0, 59.0, 67.0, 74.0, 80.0, 86.0, 114.0]
batch index 8, 0/1/2: 0/5/5
[15.0, 31.0, 47.0, 50.0, 53.0, 75.0, 77.0, 103.0, 105.0, 112.0]
batch index 9, 0/1/2: 0/4/6
[18.0, 45.0, 46.0, 54.0, 65.0, 71.0, 73.0, 76.0, 99.0, 108.0]
batch index 10, 0/1/2: 0/6/4
[11.0, 13.0, 14.0, 24.0, 51.0, 56.0, 61.0, 66.0, 87.0, 101.0]
batch index 11, 0/1/2: 0/5/5
[36.0, 39.0, 41.0, 52.0, 58.0, 62.0, 81.0, 91.0, 116.0, 118.0]
Now each sample is drawn at most once, both within a batch and across the whole epoch. Because nothing can be redrawn, the first few batches are still fairly balanced, but once the ten class-0 samples are used up (from batch 5 on), no more of them can be drawn. Classes 1 and 2 are similar in size, so they remain roughly balanced over the epoch.
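Two details worth noting here (my own additions, not from the referenced post): with replacement=False, num_samples must not exceed the number of samples, and drawing exactly len(samples_weight) samples amounts to a weighted shuffle in which every index appears exactly once:

sampler_no_rep = WeightedRandomSampler(samples_weight, len(samples_weight),
                                       replacement=False)
drawn = sorted(int(i) for i in sampler_no_rep)
assert drawn == list(range(len(samples_weight)))  # a permutation of all 120 rows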
Finally, let's drop WeightedRandomSampler and simply shuffle the samples instead:
train_loader = DataLoader(
    train_dataset, batch_size=bs, num_workers=0, shuffle=True)
The batches now look like this:
target train 0/1/2: 10/50/60
batch index 0, 0/1/2: 0/5/5
[11.0, 17.0, 34.0, 46.0, 51.0, 67.0, 70.0, 78.0, 109.0, 111.0]
batch index 1, 0/1/2: 0/6/4
[10.0, 13.0, 19.0, 29.0, 38.0, 53.0, 66.0, 85.0, 88.0, 97.0]
batch index 2, 0/1/2: 1/4/5
[2.0, 21.0, 30.0, 41.0, 48.0, 65.0, 73.0, 81.0, 102.0, 116.0]
batch index 3, 0/1/2: 2/3/5
[4.0, 5.0, 14.0, 26.0, 35.0, 64.0, 89.0, 96.0, 98.0, 115.0]
batch index 4, 0/1/2: 2/2/6
[0.0, 3.0, 16.0, 52.0, 82.0, 83.0, 92.0, 106.0, 110.0, 113.0]
batch index 5, 0/1/2: 2/4/4
[6.0, 8.0, 18.0, 40.0, 57.0, 59.0, 63.0, 101.0, 108.0, 112.0]
batch index 6, 0/1/2: 0/6/4
[12.0, 27.0, 28.0, 37.0, 39.0, 56.0, 75.0, 77.0, 87.0, 103.0]
batch index 7, 0/1/2: 0/2/8
[15.0, 25.0, 61.0, 62.0, 76.0, 93.0, 99.0, 100.0, 104.0, 119.0]
batch index 8, 0/1/2: 3/4/3
[1.0, 7.0, 9.0, 36.0, 43.0, 45.0, 55.0, 71.0, 80.0, 86.0]
batch index 9, 0/1/2: 0/4/6
[22.0, 24.0, 44.0, 50.0, 69.0, 91.0, 105.0, 114.0, 117.0, 118.0]
batch index 10, 0/1/2: 0/5/5
[23.0, 31.0, 32.0, 42.0, 47.0, 60.0, 68.0, 74.0, 79.0, 107.0]
batch index 11, 0/1/2: 0/5/5
[20.0, 33.0, 49.0, 54.0, 58.0, 72.0, 84.0, 90.0, 94.0, 95.0]
The samples are now drawn purely at random, and within a batch the classes are even more imbalanced than with replacement=False: each batch simply mirrors the raw 10/50/60 ratio.
In actual training, you can tune these parameters flexibly to suit your needs (all part of the usual alchemy).
Note: this method is mostly used for classification problems, where each training sample has exactly one label. For segmentation, a single sample contains many labels, so per-sample reweighting is inconvenient; there it is better to weight the loss function instead, e.g. nn.CrossEntropyLoss(weight=weight).
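As a closing sketch, here is one common way to build that weight tensor from class counts (the inverse-frequency scheme below is an assumption on my part, one of several reasonable choices):

import torch
import torch.nn as nn

class_counts = torch.tensor([10., 50., 60.])        # reusing the counts from above
weight = 1. / class_counts                          # inverse class frequency
weight = weight / weight.sum() * len(class_counts)  # optional: rescale to mean 1
criterion = nn.CrossEntropyLoss(weight=weight)      # rarer classes now cost more to misclassify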