在现实环境中,采集的数据(建模样本)往往是比例失衡的。比如:一个用于模型训练的数据集中,A 类样本占 95%,B 类样本占 5%。
类别的不平衡会影响到模型的训练,所以,我们需要对这种情况进行处理。处理的主要方法如下:
class_weight=“balanced” 参数 根据样本出现的频率自动给 样本设置 权重
# class_weight="balanced" 参数 根据样本出现的评论自动给 样本设置 权重
logistic_regression = LogisticRegression(random_state=0, class_weight="balanced")
model = logistic_regression.fit(features_standardized, target)
我们使用的工具如下:
pip install imbalanced-learn
示例代码:
from sklearn.datasets import make_classification
import matplotlib.pyplot as plt
from collections import Counter
# 随机过采样
def test01(X, y):
from imblearn.over_sampling import RandomOverSampler
# 构建随机过采样对象
ros = RandomOverSampler(random_state=0)
# 对X中的少数样本进行随机过采样,返回类别平衡的数据集
X_resampled, y_resampled = ros.fit_resample(X, y)
# 查看新数据集类别比例
print(Counter(y_resampled))
# 数据可视化
plt.title("过采样数据集")
plt.scatter(X_resampled[:, 0], X_resampled[:, 1], c=y_resampled)
plt.show()
# 合成少数过采样
def test02(X, y):
from imblearn.over_sampling import SMOTE
# 构建 SMOTE 对象
ros = SMOTE(random_state=0)
# 对X中的少数样本进行合成少数过采样,返回类别平衡的数据集
X_resampled, y_resampled = ros.fit_resample(X, y)
# 查看新数据集类别比例
print(Counter(y_resampled))
# 数据可视化
plt.title("过采样数据集")
plt.scatter(X_resampled[:, 0], X_resampled[:, 1], c=y_resampled)
plt.show()
if __name__ == "__main__":
# 构建数据集
X, y = make_classification(n_samples=5000,
n_features=2,
n_informative=2,
n_redundant=0,
n_repeated=0,
n_redundant 特征
n_classes=3,
n_clusters_per_class=1,
weights=[0.01, 0.05, 0.94],
random_state=0)
# 统计各类别样本数量
print(Counter(y))
# 数据可视化
plt.title("类别不平衡数据集")
plt.scatter(X[:, 0], X[:, 1], c=y)
plt.show()
# 随机过采样
test01(X, y)
# 合成少数过采样
test02(X, y)
随机欠采样: 随机减少多数类别样本数量, 达到样本数量平衡.
示例代码:
from sklearn.datasets import make_classification
import matplotlib.pyplot as plt
from collections import Counter
def test(X, y):
from imblearn.under_sampling import RandomUnderSampler
# 构建随机欠采样对象
ros = RandomUnderSampler(random_state=0)
# 对X中的少数样本进行随机过采样,返回类别平衡的数据集
X_resampled, y_resampled = ros.fit_resample(X, y)
# 查看新数据集类别比例
print(Counter(y_resampled))
# 数据可视化
plt.title("过采样数据集")
plt.scatter(X_resampled[:, 0], X_resampled[:, 1], c=y_resampled)
plt.show()
if __name__ == "__main__":
# 构建数据集
X, y = make_classification(n_samples=5000,
n_features=2,
n_informative=2,
n_redundant=0,
n_repeated=0,
n_redundant 特征
n_classes=3,
n_clusters_per_class=1,
weights=[0.01, 0.05, 0.94],
random_state=0)
# 统计各类别样本数量
print(Counter(y))
# 数据可视化
plt.title("类别不平衡数据集")
plt.scatter(X[:, 0], X[:, 1], c=y)
plt.show()
# 随机欠采样
test(X, y)