Machine Learning 5 (sklearn Decision Trees: Multiclass Classification)

1. Introduction

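This post trains a scikit-learn decision tree classifier on the bundled iris dataset, which sorts iris flowers by their features into 3 classes labeled 0, 1, and 2. The tree splits on entropy (information gain), the criterion ID3 uses. Note that scikit-learn implements an optimized version of CART, so criterion="entropy" gives ID3-style splits rather than the ID3 algorithm itself.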
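
To make the entropy criterion concrete, here is a minimal sketch of the Shannon entropy that such a criterion measures at a node. It does not load the data; it assumes the iris class balance of 50 samples per class. The helper name `entropy` is hypothetical and is not part of scikit-learn.

```python
import math
from collections import Counter


def entropy(labels):
    """Shannon entropy (in bits) of a sequence of class labels."""
    total = len(labels)
    counts = Counter(labels)
    return -sum((n / total) * math.log2(n / total) for n in counts.values())


# Assumption: iris has 50 samples in each of its 3 classes.
iris_labels = [0] * 50 + [1] * 50 + [2] * 50
root_entropy = entropy(iris_labels)
print(round(root_entropy, 3))  # 1.585, i.e. log2(3)
```

A node holding a single class has entropy 0, so each split tries to move the child nodes from 1.585 bits toward 0.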

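2. Preparing to draw the decision tree

(1) Download and install Graphviz
https://graphviz.gitlab.io/_pages/Download/Download_windows.html
(2) Install the graphviz package in PyCharm
File -> Settings -> Project (Project Interpreter) -> green + on the right -> search for graphviz and install it
(3) Parameter reference for the decision tree classifier
http://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier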

3. Python code

(1) The code to run, tree_class.py, is as follows:

from sklearn.datasets import load_iris
from sklearn import tree
from sklearn.model_selection import train_test_split
import graphviz
import os

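def multi_class_tree():
    # Load the bundled iris dataset: 150 samples, 4 features, 3 classes (0, 1, 2).
    iris = load_iris()
    x = iris['data']
    y = iris['target']
    # Split on entropy (information gain) instead of the default Gini impurity.
    dtc = tree.DecisionTreeClassifier(criterion="entropy")
    # Hold out 10% of the samples (15 of 150) as the test set.
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.1)
    clf = dtc.fit(x_train, y_train)
    # Print the predicted labels, then the true labels, for comparison.
    print(clf.predict(x_test))
    print(y_test)
    # Make the Graphviz executables findable; adjust this path to your own installation.
    os.environ["PATH"] += os.pathsep + 'F:/Program Files/Graphviz2.38/bin/'
    # Export the fitted tree as DOT source, render it to iris.pdf, and open it.
    dot_data = tree.export_graphviz(clf, out_file=None)
    graph = graphviz.Source(dot_data)
    graph.render("iris", view=True)


if __name__ == "__main__":
    multi_class_tree()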
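
If Graphviz is not installed, the fitted tree can still be inspected as plain text. The sketch below uses scikit-learn's `tree.export_text` for that. The fixed `random_state=0` and the function name `train_and_describe` are assumptions added here for reproducibility; neither appears in the script above.

```python
from sklearn.datasets import load_iris
from sklearn import tree
from sklearn.model_selection import train_test_split


def train_and_describe(random_state=0):
    """Fit an entropy-based tree on iris; return (text rules, test accuracy)."""
    iris = load_iris()
    x_train, x_test, y_train, y_test = train_test_split(
        iris.data, iris.target, test_size=0.1, random_state=random_state)
    clf = tree.DecisionTreeClassifier(criterion="entropy",
                                      random_state=random_state)
    clf.fit(x_train, y_train)
    # Render the split rules as indented text, labeled with the feature names.
    rules = tree.export_text(clf, feature_names=list(iris.feature_names))
    # score() returns the fraction of test samples classified correctly.
    return rules, clf.score(x_test, y_test)


rules, accuracy = train_and_describe()
print(rules)
print("test accuracy:", accuracy)
```

Fixing the seed makes the split and the tree repeatable between runs, which the original script does not guarantee.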

4. Verifying the results

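(1) Comparing the predictions with the test-set labels
Predicted: [0 1 2 2 2 1 1 1 0 2 1 2 0 1 0]
Actual:    [0 1 2 2 2 1 1 2 0 2 1 2 0 1 0]
In this run 14 of the 15 test samples are predicted correctly; only the eighth differs. The train/test split is random, so the numbers change from run to run.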
(2) The decision tree
[Figure: the decision tree rendered by Graphviz]
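
The comparison in (1) can be turned into an accuracy figure. This is a small sketch in plain Python that uses the two label lists printed in (1); the variable names are illustrative only.

```python
# Labels copied from the run shown in (1).
predicted = [0, 1, 2, 2, 2, 1, 1, 1, 0, 2, 1, 2, 0, 1, 0]
actual = [0, 1, 2, 2, 2, 1, 1, 2, 0, 2, 1, 2, 0, 1, 0]

# Positions where the prediction disagrees with the true label.
mismatches = [i for i, (p, a) in enumerate(zip(predicted, actual)) if p != a]
accuracy = 1 - len(mismatches) / len(actual)

print(mismatches)          # [7]
print(round(accuracy, 3))  # 0.933
```

The single miss is at index 7, where a class-2 sample was predicted as class 1.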
