全文共5320字,预计学习时长20分钟
简介
如今,使用具有数百(甚至数千)个特征的数据集已然十分普遍了。如果这些特征数量与数据集中存储的观察值数量相差无几(或者前者比后者更多)的话,很可能会导致机器学习模型过度拟合。为避免此类问题的发生,需采用正则化或降维技术(特征提取)。在机器学习中,数据集的维数等于用来表示它的变量数。
使用正则化当然有助于降低过度拟合的风险,但使用特征提取技术也具备一定的优势,例如:
· 提高准确性
· 降低过度拟合风险
· 提高训练速度
· 提升数据可视化能力
· 提高模型可解释性
特征提取旨在通过在现有数据集中创建新特征(并放弃原始特征)来减少数据集中的特征数量。这些新的简化特征集需能够汇总原始特征集中的大部分信息。这样便可以从整合的原始特征集中创建原始特征的简化版本。
特征选择也是一种常用的用来减少数据集中特征数量的技术。它与特征提取的区别在于:特征选择旨在对数据集中现有特征的重要性进行排序,放弃次重要的特征(不创建新特征)。
本文将以 Kaggle MushroomClassification Dataset为例介绍如何应用特征提取技术。本文的目标是通过观察给定的特征来对蘑菇是否有毒进行预测。
首先,需导入所有必需的数据库。
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import seaborn as sns
from sklearn import preprocessing
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report,confusion_matrix
from sklearn.ensemble import RandomForestClassifier
extraction17.py hosted with ❤ by GitHub
下图为本例中将采用的数据集。
将这些数据输入机器学习模型之前,将数据划分为特征(X)和标签(Y)以及独热码所有的分类变量。
X = df.drop(['class'], axis=1)
Y = df['class']
X = pd.get_dummies(X, prefix_sep='_')
Y = LabelEncoder().fit_transform(Y)
X = StandardScaler().fit_transform(X)
extraction15.py hosted with ❤ by GitHub
接着,创建一个函数(forest_test),将输入数据分成训练集和测试集,训练和测试一个随机森林分类器。
defforest_test(X, Y):
X_Train, X_Test, Y_Train, Y_Test = train_test_split(X, Y,
test_size=0.30,
random_state=101)
start = time.process_time()
trainedforest = RandomForestClassifier(n_estimators=700).fit(X_Train,Y_Train)
print(time.process_time() - start)
predictionforest = trainedforest.predict(X_Test)
print(confusion_matrix(Y_Test,predictionforest))
print(classification_report(Y_Test,predictionforest))
extraction14.py hosted with ❤ by GitHub
现在可以首先将该函数应用于整个数据集,然后再连续使用简化的数据集来比较二者的结果。
forest_test(X, Y)
extraction16.py hosted with ❤ by GitHub
如下图所示,使用这整个特征集训练随机森林分类器,可在2.2秒左右的训练时间内获得100%的准确率。在下列示例中,第一行提供了训练时间,供您参考。
2.2676709799999992
[[1274 0]
[ 0 1164]]
precision recall f1-score support
0 1.00 1.00 1.00 1274
1 1.00 1.00 1.00 1164
accuracy 1.00 2438
macro avg 1.00 1.00 1.00 2438
weighted avg 1.00 1.00 1.00 2438