使用WeightedRandomSampler处理类不平衡(PyTorch)

参考博客:

博客一:Address class imbalance easily with Pytorch | by Mastafa Foufa | Analytics Vidhya | Medium

播客二:Address class imbalance easily with Pytorch Part 2 | by Mastafa Foufa | Towards Data Science

类不平衡

如论文所给出的结论,处理类不平衡的主要方法是过采样。过采样应被应用至完全消除类不平衡,而优化的欠采样系数取决于不平衡的程度。与一些经典的机器学习模型不同,过采样不会导致CNN网络过拟合

假设数据集中包含两类: c l a s s 1 class_1 class1 c l a s s 2 class_2 class2,基于均匀分布,那么从 c l a s s 1 class_1 class1中随机采样得到的概率为
p ( x ∈ c l a s s i ) = # { c l a s s i } # { t r a i n } = N c l a s s i N t r a i n p(x\in class_i)=\frac{\#\{class_i\}}{\#\{train\}}=\frac{N_{class_i}}{N_{train}} p(xclassi)=#{train}#{classi}=NtrainNclassi

但是,实际可能二分类中,某一类数量远大于另一类
N c l a s s 1 ≫ N c l a s s 2 N_{class_1} \gg N_{class_2} Nclass1Nclass2
也就是说
p ( x ∈ c l a s s 1 ) ≫ p ( x ∈ c l a s s 2 ) p(x\in class_1) \gg p(x\in class_2) p(xclass1)p(xclass2)
如果使用该数据集来训练模型,那么模型看到 c l a s s 1 class_1 class1的机会远大于 c l a s s 2 class_2 class2,导致模型无法从 c l a s s 2 class_2 class2中学到有用的特征。

因此,我们应首先进行人工增强数据,即增强小类数据,使得
p ( x ∈ c l a s s 1 ) ≈ p ( x ∈ c l a s s 2 ) p(x\in class_1) \approx p(x\in class_2) p(xclass1)p(xclass2)

使用 WeightedRandomSampler

博客一以二分类(比例为9:1),给出的源代码为

使用WeightedRandomSampler处理类不平衡(PyTorch)_第1张图片

处理前和处理后的每个batch中的类分布

使用WeightedRandomSampler处理类不平衡(PyTorch)_第2张图片

由该函数的Pytorch源代码可以看出,关键思想为,由控制参数的多项式分布中进行样本采样。

使用WeightedRandomSampler处理类不平衡(PyTorch)_第3张图片

Pytorch使用多项式分布,其参数为weights, number of samples,以及采样是否放回的replacement.

使用WeightedRandomSampler处理类不平衡(PyTorch)_第4张图片

Pytorch中引入的关键思想为基于多项式分布来从一组点中进行采样。每个样本被赋予采样的概率。该概率由其类权重参数来定义。

一个简单的例子

假设数据集具有以下形式,左边为100个样本,中间为类分布,右边为WeightedRandomSampler赋予的权重参数。蓝色为大类,红色为小类。

使用WeightedRandomSampler处理类不平衡(PyTorch)_第5张图片

我们可以控制权重,对小类给予更大的权重:
W 1 ≫ W 0 W_1 \gg W_0 W1W0
权重参数设置如下:
W 0 = N N 0 = 100 90 ≈ 1.11 W_0=\frac{N}{N_0}=\frac{100}{90} \approx 1.11 W0=N0N=901001.11

W 1 = N N 1 = 100 10 = 10 W_1 = \frac{N}{N_1}=\frac{100}{10}=10 W1=N1N=10100=10

使用类似softmax函数的方法来正规化权重矢量来得到采样概率
p ( c 0 ) = 90 W 0 ( 10 W 0 + 90 W 1 ) ≈ 0.0056 p(c_0)=\frac{90W_0}{(10W_0+90W_1)}\approx0.0056 p(c0)=(10W0+90W1)90W00.0056

p ( c 1 ) = 10 W 1 ( 10 W 0 + 90 W 1 ) = 0.05 p(c_1)=\frac{10W_1}{(10W_0+90W_1)}=0.05 p(c1)=(10W0+90W1)10W1=0.05

注意:我们必须在整个数据集上进行正规化。目的是为了矢量的元素和等于1

使用WeightedRandomSampler处理类不平衡(PyTorch)_第6张图片

接下来,我们从数学上来证明,100次随机采样,可以从两个类中分别采样到50和50样本。

从类 c 1 c_1 c1中采样到的样本数为
E [ c 1 ] = ∑ i = 91 100 m ∗ p ( c 1 ) = ∑ i = 91 100 100 ∗ 0.05 = 50 E[c_1]=\sum_{i=91}^{100}{m*p(c_1)}=\sum_{i=91}^{100}{100*0.05}=50 E[c1]=i=91100mp(c1)=i=911001000.05=50
从类 c 0 c_0 c0中采样到的样本数为
E [ c 0 ] = ∑ i = 1 90 m ∗ p ( c 0 ) = ∑ i = 1 90 100 ∗ 0.0056 ≈ 50.4 E[c_0]=\sum_{i=1}^{90}{m*p(c_0)}=\sum_{i=1}^{90}{100*0.0056} \approx 50.4 E[c0]=i=190mp(c0)=i=1901000.005650.4

你可能感兴趣的:(深度学习-分类,pytorch,深度学习)