sklearn决策树可视化

过去,关于sklearn决策树可视化的教程大部分都是基于Graphviz(一个图形可视化软件)的。

Graphviz的安装比较麻烦,并不是通过pip install就能搞定的,因为要安装底层的依赖库。

现在,自版本0.21以后,scikit-learn也自带可视化工具了,它就是sklearn.tree.plot_tree()

假设决策树模型(clf)已经训练好了,画图的代码如下:

def tree1(clf):
    fig = plt.figure()
    tree.plot_tree(clf)
    fig.savefig(os.path.join(fig_dir, "tree1.png"))

没有设置图像的相关参数,画出的树结构看不清树节点的信息。
sklearn决策树可视化_第1张图片

设置字体大小,把文字调大一点:

def tree2(clf):
    fig = plt.figure()
    tree.plot_tree(clf, fontsize=8)
    fig.savefig(os.path.join(fig_dir, "tree2.png"))

文字是放大了,树节点也随着增大了,但是画面很拥挤。
sklearn决策树可视化_第2张图片
那把画布调大一点:

def tree3(clf):
    fig = plt.figure(figsize=(35, 10))
    tree.plot_tree(clf, fontsize=8)
    fig.savefig(os.path.join(fig_dir, "tree3.png"))

大功告成!
sklearn决策树可视化_第3张图片

下面的代码包含数据读取、模型训练和画图,有注释,就不展开了。

关注【小猫AI】公众号,回复tree可以获取训练模型的数据哦。

# -*- coding: utf-8 -*-
"""
Description : sklearn决策树可视化(scikit-learn==0.24.2)。
Authors     : wapping
CreateDate  : 2022/2/7
"""
import os
import pandas as pd
from sklearn import tree
from matplotlib import pyplot as plt


def read_data(fp):
    """加载训练数据。"""
    data = pd.read_csv(fp, header=None)
    x = data[[0, 1]]    # 第0,1列为特征
    y = data[[2]]       # 第2列为标签
    return x, y


def tree1(clf):
    # 没有设置图像的相关参数,画出的树结构看不清树节点的信息
    fig = plt.figure()
    tree.plot_tree(clf)
    fig.savefig(os.path.join(fig_dir, "tree1.png"))


def tree2(clf):
    # 设置字体大小,树节点放大了,但是很拥挤
    fig = plt.figure()
    tree.plot_tree(clf, fontsize=8)
    fig.savefig(os.path.join(fig_dir, "tree2.png"))


def tree3(clf):
    # 同时设置字体大小和图像的大小,树结构正常显示
    fig = plt.figure(figsize=(35, 10))
    tree.plot_tree(clf, fontsize=8)
    fig.savefig(os.path.join(fig_dir, "tree3.png"))


if __name__ == '__main__':
    fig_dir = "data/plot_tree"      # 保存图片的目录
    data_path = "data/plot_tree_data.csv"   # 训练树模型的数据
    os.makedirs(fig_dir, exist_ok=True)

    # 读取训练数据
    x, y = read_data(data_path)

    # 训练决策树分类器
    clf = tree.DecisionTreeClassifier(min_samples_leaf=100, random_state=666)
    clf = clf.fit(x, y)

    # 画树结构并保存图片
    tree1(clf)
    tree2(clf)
    tree3(clf)

你可能感兴趣的:(机器学习,决策树,sklearn,机器学习)