决策树可视化的两种方法

一、导入数据集并进行训练

# 引入数据集

from sklearn.datasets import load_iris
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier     
import matplotlib.pyplot as plt

iris = load_iris()
iris_feature = iris.data #特征数据
iris_target = iris.target #分类数据

print (iris.data)         
print (iris.target)       
clf = DecisionTreeClassifier()      
clf.fit(iris.data, iris.target)     


# 获取花卉两列数据集
X = iris.data
L1 = [x[0] for x in X]
L2 = [x[1] for x in X]
 

    #绘图
plt.scatter(X[:50, 0], X[:50, 1], color='red', marker='o', label='setosa')
plt.scatter(X[50:100, 0], X[50:100, 1], color='blue', marker='x', label='versicolor')
plt.scatter(X[100:, 0], X[100:, 1], color='green', marker='s', label='Virginica')
plt.title("DTC")
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')

plt.xticks(())
plt.yticks(())
plt.legend(loc=2)
plt.show()

import pandas
#导入数据集iris
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
names = ['sepal-length', 'sepal-width', 'petal-length', 'petal-width', 'class']
dataset = pandas.read_csv(url, names=names) 
print(dataset.describe())

dataset.hist()
plt.show()
 

二、决策树可视化,方法有两种

方法一:

with open('tree.dot','w') as f:
    f = tree.export_graphviz(clf, out_file=f,
                     feature_names=iris.feature_names,
                     class_names=iris.target_names,
                     filled=True, rounded=True,
                     special_characters=True)

运行代码后,在文件夹里会有一个tree.dot文件,安装graphviz(去官网下载graphviz安装文件),然后在命令行中编译 dot -Tpdf tree.dot -o tree.pdf。在文件夹里打开pdf文件就可以看到决策树图。

方法二:

直接生成pdf可视化结果。调试过程中遇到six包和pydot包的两个问题, 输入下面修改后的代码。

#from sklearn.externals.six import StringIO #新版本中可以直接从six包中导入StringIO
from six import StringIO

#import pydot #运行最后一行代码时,总是报错,把pydot库改为pydotplus库
import pydotplus  


dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data,
                     feature_names=iris.feature_names,
                     class_names=iris.target_names,
                     filled=True, rounded=True,
                     special_characters=True)
graph = pydot.graph_from_dot_data(dot_data.getvalue())
print(len(graph))  # 1
print(graph)  # []
print(graph[0])  #
# graph.write_pdf("iris.pdf")
graph[0].write_pdf("iris.pdf")

graph[0].write_pdf()

你可能感兴趣的:(决策树,机器学习,python)