数据不平衡:下采样、上采样python代码实现

一、下采样

所有数据存在DataFrame对象df中。数据分为两类:多数类别和少数类别,数据量相差大。数据预处理已将多数类别的Label标记为1,少数类别的Label标记为0。从多数类中随机抽取样本(抽取的样本数量与少数类别样本量一致)从而减少多数类别样本数据,使数据达到平衡的方式。

import numpy as np
import pandas as pd


def lower_sample_data(df, percent=1):
    '''
    percent:多数类别下采样的数量相对于少数类别样本数量的比例
    '''
    data1 = df[df['Label'] == 1]  # 将多数类别的样本放在data1
    data0 = df[df['Label'] == 0]  # 将少数类别的样本放在data0
    index = np.random.randint(
        len(data1), size=percent * (len(df) - len(data1)))  # 随机给定下采样取出样本的序号
    lower_data1 = data1.iloc[list(index)]  # 下采样
    return(pd.concat([lower_data1, data0]))

示例:

np.random.seed(28)
arr1 = np.random.randint(6, size=(100, 5))
arr2 = np.random.randint(1000, 1010, size=(10, 5))
columns = ['A', 'B', 'C', 'D', 'E']
df1 = pd.DataFrame(arr1, columns=columns)
df1['Label'] = 1
df2 = pd.DataFrame(arr2, columns=columns)
df2['Label'] = 0
df = pd.concat([df1, df2])
print(lower_sample_data(df))

输出:

       A     B     C     D     E  Label
37     4     3     0     1     4      1
41     5     5     5     4     4      1
35     5     3     2     2     5      1
69     0     0     1     0     4      1
98     2     4     5     2     0      1
78     3     3     2     4     3      1
52     2     2     1     3     3      1
43     0     0     5     5     4      1
61     5     0     1     0     5      1
86     3     2     0     1     4      1
0   1002  1005  1004  1005  1002      0
1   1007  1009  1005  1000  1003      0
2   1004  1005  1000  1003  1005      0
3   1002  1003  1000  1009  1003      0
4   1000  1002  1005  1009  1006      0
5   1001  1009  1003  1007  1003      0
6   1009  1004  1005  1007  1002      0
7   1008  1006  1009  1009  1009      0
8   1003  1007  1006  1007  1005      0
9   1001  1008  1003  1008  1003      0

 

二、上采样

和欠采样采用同样的原理,通过抽样来增加少数样本的数目,从而达到数据平衡的目的。一种简单的方式就是通过有放回抽样,不断的从少数类别样本数据中抽取样本,然后使用抽取样本+原始数据组成训练数据集来训练模型;不过该方式比较容易导致过拟合,一般抽样样本不要超过50%。

因为在上采样过程中,是进行是随机有放回的抽样,所以最终模型中,数据其实是相当于存在一定的重复数据,为了防止这个重复数据导致的问题,我们可以加入一定的随机性,也就是说:在抽取数据后,对数据的各个维度可以进行随机的小范围变动,eg: (1,2,3) --> (1.01, 1.99, 3);通过该方式可以相对比较容易的降低上采样导致的过拟合问题。

 

 

你可能感兴趣的:(python,机器学习)