(DecisionTreeClassifier)决策树可视化实例-鸢尾花数据分类 学习笔记

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_graphviz
import pydotplus
from IPython.display import display, Image

# Load the iris dataset: x is the feature matrix, y the class labels.
iris = load_iris()
x = iris.data
y = iris.target

# Hold out 30% of the samples as a test set. train_test_split returns, in
# order: train features, test features, train labels, test labels. The fourth
# target must be y_test; unpacking into y_train twice would overwrite the
# training labels with the test labels and leave y_test undefined.
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.3, random_state=0
)

# Build a decision tree that uses entropy as its split criterion, fit it on
# the training split, then predict labels for the held-out split.
dtc = DecisionTreeClassifier(criterion='entropy')
dtc.fit(x_train, y_train)
y_test_pre = dtc.predict(x_test)

num = x.shape[0]  # total number of samples
num_train = x_train.shape[0]  # number of training samples
num_test = num - num_train  # number of test samples
# Accuracy = fraction of test samples whose predicted label matches the truth.
acc = sum(y_test_pre == y_test) / num_test
print('The accuracy is ', acc)
Out:
The accuracy is 0.9777777777777777

# The manual accuracy computation above (counting matching predictions and
# dividing by the test-set size) can be replaced by this single call.
# As a bare expression its value is only shown when run as the last
# statement of a notebook cell.
dtc.score(x_test, y_test)

0.9777777777777777

# Export the fitted tree as Graphviz DOT source. out_file=None makes
# export_graphviz return the DOT text as a string instead of writing a file.
dot_data = export_graphviz(
    dtc,
    out_file=None,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    filled=True,
    rounded=True,
    special_characters=True,
)
# Parse the DOT text into a pydotplus graph, save it as a PNG on disk, and
# also render it inline in the notebook.
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_png('iris.png')
display(Image(graph.create_png()))

图1:(DecisionTreeClassifier)决策树可视化结果——鸢尾花数据分类(对应上面代码生成的 iris.png)

你可能感兴趣的:(机器学习入门)