决策树是一种用于机器学习的监督学习算法。它使用二叉树结构(每个内部节点有两个子节点)为每个数据样本分配一个目标值,目标值位于树的叶节点中。为了到达叶节点,样本从根节点开始沿节点逐层向下传播。在每个内部节点上,都要决定样本应进入哪个子节点,这一决定依据样本的特征做出。决策树学习就是根据所选指标,在每个内部节点中寻找最佳划分规则的过程。这些都是老生常谈的内容了,希望大家简单了解一下即可。
import sklearn.datasets as datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
from matplotlib import pyplot as plt
import pandas as pd
# Load the iris dataset and keep the features in a DataFrame so the column
# names travel with the data.
iris = datasets.load_iris()
X_df = pd.DataFrame(iris.data, columns=iris.feature_names)
print(X_df.head(15))
Y = iris.target
print("\nClass Labels for all the data points:\n", Y)

# Fit the model.
# random_state fixes the random feature permutation scikit-learn applies at
# every split, so ties between equally good splits are broken the same way on
# each run and the tree (and every file saved from it) is reproducible.
dtree = DecisionTreeClassifier(random_state=1234)
model = dtree.fit(X_df, Y)  # fit() returns the estimator itself, so model is dtree

# Plain-text rendering of the rules. Passing feature_names makes the rules use
# the iris column names instead of generic feature_<i> placeholders.
text_representation = tree.export_text(dtree, feature_names=list(iris.feature_names))
print(text_representation)

# Save the result. The write must be indented under `with` so the file is
# closed only after it has been written.
with open("iris_DecisionTree_text.txt", "w", encoding="utf-8") as fout:
    fout.write(text_representation)
plot_tree方法在sklearn 0.21版本中加入,使用前需要安装matplotlib。它可以直接绘制树形图,无需先导出为graphviz格式。
# Render the fitted tree with matplotlib (no graphviz export needed).
fig = plt.figure(figsize=(25, 20))
# filled=True colours each node according to its majority class.
_ = tree.plot_tree(dtree,
                   feature_names=iris.feature_names,
                   class_names=iris.target_names,
                   filled=True)
# Save the figure like the other renderings; when run as a plain script the
# figure would otherwise never be shown or written anywhere.
fig.savefig("iris_DecisionTree_plot_tree.png")
在plot_tree中使用filled=True:当该参数设置为True时,该方法会用颜色表示每个节点中的多数类。(如果能有一个将类别与颜色对应起来的图例就更好了。)
from IPython.display import Image
from sklearn.tree import export_graphviz
import pydotplus
print("Import Successful")
# Draw the decision tree.
# out_file=None makes export_graphviz return the DOT source as a str, so no
# StringIO buffer is needed, and no sys.modules patch faking the removed
# sklearn.externals.six module either.
dot_data = export_graphviz(dtree, out_file=None,
                           feature_names=iris.feature_names,
                           class_names=iris.target_names,
                           filled=True, rounded=True,
                           special_characters=True, node_ids=True)
graph = pydotplus.graph_from_dot_data(dot_data)
# Bare expression: only displayed when it is the final expression of an
# interactive (e.g. Jupyter) cell; in a plain script the PNG is built and discarded.
Image(graph.create_png())
# Save the rendered tree.
graph.write_png("iris_DecisionTree_graphivz1.png")
import graphviz

# Export the fitted tree as DOT source text; with out_file=None the DOT data
# comes back as the return value instead of being written to a file.
dot_data = tree.export_graphviz(
    dtree,
    out_file=None,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    filled=True,
)
# Wrap the DOT text for rendering; switch format="png" to "pdf" for PDF output.
graph = graphviz.Source(dot_data, format="png")
graph  # only displayed when evaluated at the end of an interactive cell
# Save the rendering under this base name (output type follows `format`).
graph.render("iris_DecisionTree_graphivz2")
# dtreeviz is a separate third-party package; remember to install it first.
# NOTE(review): dtreeviz 2.x reorganised its API around dtreeviz.model(...);
# confirm this legacy entry point works with the installed version.
from dtreeviz.trees import dtreeviz

# Positional arguments: the fitted model, the feature frame and the targets.
viz = dtreeviz(
    dtree,
    X_df,
    Y,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
    target_name="target",
)
viz  # only displayed when evaluated at the end of an interactive cell
viz.save("iris_DecisionTree_dtreeviz.svg")