针对数据不平衡的加权随机采样WeightedRandomSampler

目录

背景

WeightedRandomSampler使用

代码实现


背景

由于我们不能将大量数据一次性放入网络中进行训练,所以需要分批进行数据读取。这一过程涉及到如何从数据集中读取数据的问题,当遇到样本数据不平衡的问题,可以使用加权随机采样器WeightedRandomSampler根据各个样本数据的权重来平衡每个batch中各样本数据,提升模型的性能。

pytorch提供了Sampler基类与多个子类实现不同方式的数据采样。子类包含:

  • SequentialSampler(顺序采样)
  • RandomSampler(随机采样)
  • SubsetRandomSampler(子集随机采样)
  • WeightedRandomSampler(加权随机采样)

WeightedRandomSampler使用

官方解释:

针对数据不平衡的加权随机采样WeightedRandomSampler_第1张图片

WeightedRandomSampler中,采样的权重针对的是每⼀个样本,所以我们可以确定好每个类对应的权重,再⼀⼀对应到样本上。并且,权重其实就是⽐值,num_samples就是⼀次采样的数⽬,⾥⾯的⽐值其实就是权重的⽐值。

假设分类问题,分为3类。

sampler = WeightedRandomSampler(samples_weight,samples_num,replacement=True)

samples_weight的每一项代表该样本种类占总样本的比例的倒数。

samples_num 为我们想采集多少个样本,可以重复采集。假设为2000。

replacement=True 表示对数据有放回的采样

假设3类样本分布比例为 猫,狗,猪 = 0.1,0.2,0.7

Count = [0.1,0.2,0.7]

Weight = 1/Count = [10,5,1.43]

samples_weight内全是 10或5或1.43,是10代表该样本是猫...

假设samples_weight内样子是:

[10,5,5,1.43,1.43,1.43,1.43.......,10]

即表示:

位置0的权重是10

位置1的权重是5

位置2的权重是5

位置3的权重是1.43

....

最后一个的权重是10

10的采样次数最少,但是权重最大,1.43采样次数多但权重小,所以达到了样本平衡的效果。

代码实现

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from torch.autograd import Variable

trsfm = transforms.Compose([
            transforms.ToTensor(),
        ])
train_dataset = datasets.ImageFolder(data_dir, trsfm)
train_sampler = creater_sampler(train_dataset)
train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size, shuffle=False, num_workers=0)

#按样本权重采样器
def creater_sampler(self,train_set):
        classes_idx = train_set.class_to_idx
        appear_times = Variable(torch.zeros(len(classes_idx), 1))
        for label in train_set.targets:
            appear_times[label] += 1
        classes_weight = (1./(appear_times / len(train_set))).view( -1)
        weight=list(map(lambda x:classes_weight[x],train_set.targets))
        num_sample=int(len(train_set))
        print("total:{}".format(num_sample)+",targets0:{},targets1:{}".format(appear_times[0],appear_times[1]))
        sampler = WeightedRandomSampler(weight, num_sample, replacement=True)
        return sampler

补充:
lambda x:classes_weight[x]  表示简易函数输入以x为变量输出classes_weight[x]值的函数

map()函数的原型是map(function,iterable,……),它的结果是返回一个列表,即本代码中train_set.targets作为x输入

list()–返回给定元素的列表

你可能感兴趣的:(深度学习,深度学习,cnn,pytorch,计算机视觉,人工智能)