Scikit-Learn 机器学习笔记 -- 决策树

Scikit-Learn 机器学习笔记 – 决策树


参考文档: handson-ml


import numpy as np


# 加载鸢尾花数据集
def load_dataset():
    from sklearn import datasets
    iris = datasets.load_iris()
    # print(iris)
    # 使用第3和第4个特征
    X = iris['data'][:, (2, 3)]
    # bool类型转为数值型
    y = iris['target']
    return X, y, iris


# 决策树分类器
def tree_classify(X, y):
    from sklearn.tree import DecisionTreeClassifier
    tree_clf = DecisionTreeClassifier(max_depth=2)
    tree_clf.fit(X, y)
    print(tree_clf.tree_)
    return tree_clf


# 绘制决策树图
def draw_tree(model, iris):
    from sklearn.tree import export_graphviz
    export_graphviz(
        model,
        out_file="iris_tree.dot",
        feature_names=iris.feature_names[2:],
        class_names=iris.target_names,
        rounded=True,
        filled=True
    )
    # import pydotplus
    # dot_data = export_graphviz(model, out_file=None)
    # graph = pydotplus.graph_from_dot_data(dot_data)
    # graph.write_pdf("iris.pdf")


# 预测
def tree_predict(model, sample):
    # 预测类别
    predict = model.predict(sample)
    # 属于各类别的概率
    predict_prob = model.predict_proba(sample)
    print('决策树预测类别为:', predict, '属于各类别的概率为:', predict_prob)


if __name__ == '__main__':
    # 加载数据集
    X, y, iris = load_dataset()
    # 创建决策树分类器
    tree_clf = tree_classify(X, y)
    # 绘制决策树
    # draw_tree(tree_clf, iris)
    # 预测
    tree_predict(tree_clf, [[5, 1.5]])

你可能感兴趣的:(机器学习,机器学习,sk-learn,决策树)