决策树(Decision Tree):通俗易懂之介绍
sklearn.tree.DecisionTreeClassifier( criterion='gini',
splitter='best',
max_depth=None,
min_samples_split=2,
min_samples_leaf=1,
min_weight_fraction_leaf=0.0,
max_features=None,
random_state=None,
max_leaf_nodes=None,
min_impurity_decrease=0.0,
min_impurity_split=None,
class_weight=None,
presort='deprecated',
ccp_alpha=0.0)
1、导入相关模块
import sklearn
from sklearn import datasets,tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
2、加载并划分数据
# Load the iris dataset and separate features from labels.
iris = datasets.load_iris()
iris_feature = iris.data      # feature matrix
iris_target = iris.target     # integer class labels
# Hold out 33% of the samples for evaluation; a fixed random_state
# makes the split (and therefore the printed results) reproducible.
feature_train, feature_test, target_train, target_test = train_test_split(
    iris_feature, iris_target, test_size=0.33, random_state=0)
3、建立决策树模型、训练预测以及计算准确度
# 'gini' corresponds to the CART algorithm, 'entropy' to ID3.
# min_samples_leaf=3 lightly regularizes the tree (no leaf smaller than 3 samples).
dt_model = tree.DecisionTreeClassifier(criterion='gini', min_samples_leaf=3)
dt_model.fit(feature_train, target_train)
predict_results = dt_model.predict(feature_test)
print (predict_results)
print (target_test)
# accuracy_score's documented signature is (y_true, y_pred). Accuracy itself is
# symmetric so the value is unchanged, but passing arguments in the documented
# order avoids silent errors if this is later swapped for a non-symmetric metric.
print (accuracy_score(target_test, predict_results))
[2 1 0 2 0 2 0 1 1 1 2 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0 2 1 0 2 2 1 0
2 1 1 2 0 2 0 0 1 2 2 2 2]
[2 1 0 2 0 2 0 1 1 1 2 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0 2 1 0 2 2 1 0
1 1 1 2 0 2 0 0 1 2 2 2 2]
0.98
4、绘制决策树
%matplotlib inline
fig=plt.figure(figsize=(10,10))
tree.plot_tree(dt_model,filled="True",
feature_names=["Sepal length","Sepal width",
"Petal length","Petal width"],
class_names=["setosa","versicolor","virginica"])