目录
背景
WeightedRandomSampler使用
代码实现
由于我们不能将大量数据一次性放入网络中进行训练,所以需要分批进行数据读取。这一过程涉及到如何从数据集中读取数据的问题,当遇到样本数据不平衡的问题,可以使用加权随机采样器WeightedRandomSampler根据各个样本数据的权重来平衡每个batch中各样本数据,提升模型的性能。
pytorch提供了Sampler基类与多个子类实现不同方式的数据采样。子类包含:
官方解释:
WeightedRandomSampler中,采样的权重针对的是每⼀个样本,所以我们可以确定好每个类对应的权重,再⼀⼀对应到样本上。并且,权重其实就是⽐值,num_samples就是⼀次采样的数⽬,⾥⾯的⽐值其实就是权重的⽐值。
假设分类问题,分为3类。
sampler = WeightedRandomSampler(samples_weight,samples_num,replacement=True)
samples_weight的每一项代表该样本种类占总样本的比例的倒数。
samples_num 为我们想采集多少个样本,可以重复采集。假设为2000。
replacement=True 表示对数据有放回的采样
假设3类样本分布比例为 猫,狗,猪 = 0.1,0.2,0.7
Count = [0.1,0.2,0.7]
Weight = 1/Count = [10,5,1.43]
samples_weight内全是 10或5或1.43,是10代表该样本是猫...
假设samples_weight内样子是:
[10,5,5,1.43,1.43,1.43,1.43.......,10]
即表示:
位置0的权重是10
位置1的权重是5
位置2的权重是5
位置3的权重是1.43
....
最后一个的权重是10
10的采样次数最少,但是权重最大,1.43采样次数多但权重小,所以达到了样本平衡的效果。
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from torch.autograd import Variable
trsfm = transforms.Compose([
transforms.ToTensor(),
])
train_dataset = datasets.ImageFolder(data_dir, trsfm)
train_sampler = creater_sampler(train_dataset)
train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size, shuffle=False, num_workers=0)
#按样本权重采样器
def creater_sampler(self,train_set):
classes_idx = train_set.class_to_idx
appear_times = Variable(torch.zeros(len(classes_idx), 1))
for label in train_set.targets:
appear_times[label] += 1
classes_weight = (1./(appear_times / len(train_set))).view( -1)
weight=list(map(lambda x:classes_weight[x],train_set.targets))
num_sample=int(len(train_set))
print("total:{}".format(num_sample)+",targets0:{},targets1:{}".format(appear_times[0],appear_times[1]))
sampler = WeightedRandomSampler(weight, num_sample, replacement=True)
return sampler
补充:
lambda x:classes_weight[x] 表示简易函数输入以x为变量输出classes_weight[x]值的函数
map()函数的原型是map(function,iterable,……),它的结果是返回一个列表,即本代码中train_set.targets作为x输入
list()–返回给定元素的列表