深度理解PyTorch的WeightedRandomSampler处理图像分类任务的类别不平衡问题

最近做活体检测任务,将其看成是一个图像二分类问题,然而面临的一个很大问题就是正负样本的不平衡问题,也就是正样本(活体)很多,而负样本(假体)很少,如何处理好数据集的类别不平衡问题有很多方法,如使用加权的交叉熵损失(nn.CrossEntropyLoss(weight=weight)),但是更加有效的一个实践是在模型训练的过程中过采样少数类样本,增加这些少数类样本被模型看到的频率。

pytorch提供了一个WeightedRandomSampler 帮助完成以上任务。

torch.utils.data — PyTorch 2.0 documentation

深度理解PyTorch的WeightedRandomSampler处理图像分类任务的类别不平衡问题_第1张图片

通用的使用方法如下:


[步骤1] class_sample_count = [10, 1, 20, 3, 4] # dataset has 10 class-1 samples, 1 class-2 samples, etc.
[步骤2] weights = 1 / torch.Tensor(class_sample_count)
[步骤3] # 将weights赋予所有的训练样本,作为每个训练样本的权重,标记为 samples_weight
[步骤4] sampler = torch.utils.data.sampler.WeightedRandomSampler(samples_weight, num_samples=len(samples_weight), replacement=True) 
[步骤5] trainloader = data_utils.DataLoader(train_dataset, batch_size = 20, sampler = sampler) 

特别注意: 第5步骤中,一旦设置了sampler, 就不能再设置shuffle,参考:https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader

怎么理解这些东西呢?

首先第一步需要计算所有类别的个数,这样可以知道哪些类别数量大,而哪些类别数量少。

其次,计算每个类别的权重,只需要简单的取每个类别的个数之倒数即可。

第三,对于每个训练样本,根据其标签,获取其对应类别的权重,作为训练时样本采样的概率。

第四,就是创建采样器,其中num_samples通常取为weights的个数,即训练样本的个数。

大家可以想想,对于样本少的类别,以二分类为例,

[假, 真, 真, 真, 真, 真]

假的权重 = 1/1 , 真的权重= 1/5, 

于是上面的samples_weight取为:

[1, 0.2,0.2, 0.2,0.2, 0.2],

因此取得假样本的期望值 = 1*1=1, 取得真样本的期望值=0.2*5=1, 这样二者就平衡了。因此假样本被过采样,相对地,真样本被欠采样了。

试验

深度理解PyTorch的WeightedRandomSampler处理图像分类任务的类别不平衡问题_第2张图片

 由此看来总体是多次采样的样本基本是真假平衡的。

那么也许很多人会问:

1. 按照上述概率采样方式,是否所有的训练样本都被模型看到了呢?

2. 如果我不想让数据集均匀分布,而是想达到其他比例呢?

关于这两个问题,博文《Demystifying PyTorch’s WeightedRandomSampler by example》给了一个很详细的回答。

我总结在这里:

1. 在一轮(epoch)训练中,确实可能存在部分样本没有被模型看到,增加num_samples 为训练数据集的样本数量的两倍,会使得一轮迭代过程中看到更多的图像,但是一般仍然推荐设置num_samples 为训练数据集的样本数量,并且相信,当我们训练更多轮以后,所有的图像都将在某一个点处被看到。

2.  对于类别不平衡的数据集,一般在9-10轮以后就会看全所有的样本,而对于类别均衡的数据集,采用上述方法采样,需要大致经过5轮才能看完所有的样本(这种情况下就不用采取这种采样策略了)。

3.  看博文吧。

深度理解PyTorch的WeightedRandomSampler处理图像分类任务的类别不平衡问题_第3张图片

注:该方法多用于分类问题。即一个训练样本对应一个标签。对于分割问题,一个样本中有很多标签,用该方法就不太方便。分割问题推荐给损失函数添加权重,如nn.CrossEntropyLoss(weight=weight)。

参考文献:

1. Demystifying PyTorch’s WeightedRandomSampler by example
https://gist.github.com/Chris-hughes10/260c70650c5a6f322d273a8a8728b91a

2. Pytorch样本比例不均衡时采用WeightedRandomSampler进行采样

3. torch.utils.data.WeightedRandomSampler样本不均衡情况下带权重随机采样

4. WeightedRandomSampler 理解了吧

你可能感兴趣的:(pytorch,深度学习,人工智能)