scikit-learn Random Forest Code Study: Breast Cancer Detection

from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

data = load_breast_cancer()

# Baseline: mean accuracy of a 100-tree forest under 10-fold cross-validation
rfc = RandomForestClassifier(n_estimators=100, random_state=90)
score_pre = cross_val_score(rfc, data.data, data.target, cv=10).mean()
print(score_pre)

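# Learning curve over n_estimators: try 1, 11, 21, ..., 191 trees
scorel = []
for i in range(0, 200, 10):
    # random_state is fixed so the curve is reproducible and comparable to the baseline
    rfc = RandomForestClassifier(n_estimators=i + 1, n_jobs=-1, random_state=90)
    rfc_s = cross_val_score(rfc, data.data, data.target, cv=10).mean()
    scorel.append(rfc_s)
# Best mean score and the n_estimators value that produced it
print(max(scorel), scorel.index(max(scorel)) * 10 + 1)
plt.figure(figsize=[20, 5])
plt.plot(range(1, 201, 10), scorel)
plt.show()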

[Figure 1: plot produced by the code above, mean 10-fold cross-validation score versus n_estimators]

Further parameter tuning

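# Same approach: continue the search near 131
# Grid search over max_depth with n_estimators fixed at 39
param_grid = {'max_depth': np.arange(1, 20, 1)}
rfc = RandomForestClassifier(n_estimators=39, random_state=90)
GS = GridSearchCV(rfc, param_grid, cv=10)
GS.fit(data.data, data.target)
# Print explicitly so the results also appear when run as a script
print(GS.best_params_)
print(GS.best_score_)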
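
Beyond the single best value, it can help to see how the score changes across the whole grid. The following is a minimal sketch, not part of the original code: it assumes a smaller max_depth range (1 to 10) and 5 folds purely to keep the run short, and the names gs and results are illustrative.

```python
import numpy as np
import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

data = load_breast_cancer()

# Assumption: a reduced grid and cv=5 keep this sketch quick to run.
param_grid = {"max_depth": np.arange(1, 11)}
rfc = RandomForestClassifier(n_estimators=39, random_state=90)
gs = GridSearchCV(rfc, param_grid, cv=5)
gs.fit(data.data, data.target)

# cv_results_ holds one entry per candidate; keep the columns of interest
# and sort so that the best-ranked max_depth comes first.
results = pd.DataFrame(gs.cv_results_)[
    ["param_max_depth", "mean_test_score", "std_test_score", "rank_test_score"]
].sort_values("rank_test_score")
print(results.head())
```

If the mean score is nearly flat across depths, max_depth is probably not the parameter limiting the model, and other parameters may be worth tuning instead.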

In this case, we should focus on achieving better recall, so that every cancer patient is detected.
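
One way to measure this is to score the cross-validation on recall for the malignant class instead of accuracy. This is a hedged sketch rather than the original author's code: the scorer name malignant_recall is illustrative, and n_estimators=39 is simply reused from the grid search above. In load_breast_cancer, label 0 is malignant and label 1 is benign, so pos_label=0 must be set explicitly; the default recall scorer would measure recall on benign cases.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import make_scorer, recall_score
from sklearn.model_selection import cross_val_score

data = load_breast_cancer()

# target_names is ['malignant', 'benign'], i.e. 0 = malignant, 1 = benign.
# Recall on cancer cases therefore needs pos_label=0.
malignant_recall = make_scorer(recall_score, pos_label=0)

rfc = RandomForestClassifier(n_estimators=39, random_state=90)
recall_scores = cross_val_score(
    rfc, data.data, data.target, cv=10, scoring=malignant_recall
)
print(recall_scores.mean())
```

The same scorer can be passed to GridSearchCV through its scoring parameter, so that parameters are selected by malignant-class recall rather than by accuracy.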
