大家好,我是菜鸟君,之前跟大家聊过R语言的随机森林建模,指路 R语言 | 随机森林建模实战(代码+详解),作为刚过完1024节日的码农算法工程师来说,怎么可能只会用一种语言呢?今天就来说说Python怎么进行随机森林的模型构建。
首先,加载一些我们需要的库。需要注意的是,我们想进行分类预测,所以加载的是随机森林分类功能。如果想进行回归预测,需要加载随机森林回归功能哦(RandomForestRegression)。
# -*- coding: utf-8 -*-from sklearn.datasets import load_iris #iris数据集from sklearn.ensemble import RandomForestClassifier #随机森林分类from sklearn import metrics #模型结果指标库import pandas as pd import numpy as npimport matplotlib.pyplot as plt #画图
然后,我们加载数据集。今天用到的是自带的鸢尾花数据集。
# 还是用鸢尾花数据集吧iris=load_iris()#iris的4个属性,标签是花的种类print (iris['target'].shape)
建模这就来了!这一步,我们一共构建了两个模型,rf1采用默认参数,也就是括号里为空。rf2指定了一些参数。都是用前130个数据作为训练集。我们可以对比下两者的结果有啥差异。
## 随机森林分类器rf1 = RandomForestClassifier()rf2 = RandomForestClassifier(n_estimators=10, max_depth=None,min_samples_split=3, random_state=0)rf1.fit(iris.data[:130],iris.target[:130]) #用前130个数据作为训练集rf2.fit(iris.data[:130],iris.target[:130]) #用前130个数据作为训练集
模型构建结束以后,就轮到测试集了。测试集与训练集一定是分开、独立的。不然就没法客观衡量模型的预测效果啦。在这里,我们的测试集是从131行到150行的鸢尾花数据集。
testset=iris.data[131:150]y_true = iris.target[131:150]
接下来就是把两个随机森林模型在同一个测试集上进行预测,分别打印出预测的结果。
y_pre1 = rf1.predict(testset)y_pre2 = rf2.predict(testset)print ('rf1 prediction;',y_pre1)print ('rf2 prediction;',y_pre2)
接下来就是结果部分啦。
# Calculate metricsmse1 = metrics.mean_squared_error(y_true, y_pre1)mse2 = metrics.mean_squared_error(y_true, y_pre2)print("MSE1: %.4f" % mse1)print("MSE2: %.2f" % mse2)
从MSE结果看出,第一个模型rf1的均方误差比第二个模型rf2的均方误差要更大一点。说明rf1的结果比rf2要差。
Python sklearn的metrics库十分强大,这里只是用MSE这个指标做个示范,还有准确率、灵敏度、特异度等指标,只要你需要,基本都能从这个库里找到。
这是文档,给你们看一部分指标
最后,随机森林可以显示各特征对分类的贡献,即重要性评分。我们可以打印出重要性的分数,也可以基于这个分数来进行可视化展示。
feature_importance = rf2.feature_importances_# make importances relative to max importancefeature_importance = 100.0 * (feature_importance / feature_importance.max())print(feature_importance) #use inbuilt class feature_importances #plot graph of feature importances for better visualizationfeat_importances = pd.Series(rf1.feature_importances_, index=pd.DataFrame(iris.data).columns)feat_importances.nlargest(5).plot(kind='barh')plt.title('Variable Importance')plt.xlabel('Relative Importance')plt.show()
天气变冷了,只要代码不报错,内心就不会觉得冷,加油打工人!下个周末咱们再约,嗷~
交流QQ群:83837564
B站视频:谁说菜鸟不会数据分析
别忘了点“在看”