17. 决策树参数实例

使用鸢尾花数据集构建决策树

决策树参数

  • 1 熵/Gini系数作为评判指标
  • 2 splitter: 所有或随机在特征中找最好的切分点
  • 3 max_features: None(所有值), log2, sqrt, N, 特征小于50时一般选择所有
  • 4 max_depth: 达到此最大深度时停止
  • 5 min_samples_split: 达到此值时停止分裂
  • 6 min_samples_leaf: 限制叶子节点最小样本数,小于该值则被剪枝,过大则被分裂
  • 7 min_weight_fraction_leaf: 叶子节点的权重项,表示叶子节点所有样本权重和的最小值, 当权重项与样本个数的乘积<阈值则剪掉,不再分裂
  • 8 max_leaf_nodes: 超过该值不再分裂,默认None
  • 9 class_weight: 样本各类别的权重,样本数*该值>某值,则剪掉
  • 10 min_impurity_split: 某节点不纯度<该阈值则该节点不再分裂,成为叶子节点

代码

# 决策树上解决过拟合:剪枝
import pandas as pd
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier
from sklearn import preprocessing

%matplotlib inline 
# 在jupyter notebook或jupyter qtconsole中用到,
# 且在plot()绘制图或figure()创建画板的时候可直接在python console中生成图像
import seaborn as sb

# 载入鸢尾花数据集
iris_data = pd.read_csv('iris.csv')
# iris2 = load_iris()
# print('iris2:', iris2)
# print('iris: ' ,iris_data)

# 查看iris_data所有属性或方法
# print('dir: ' ,dir(iris_data))
# 查看数据集简介
# print('DESCR: ', iris_data.DESCR)
# feature_names = iris_data.feature_names
# print('feature_names: ', feature_names)

iris_data.columns = ['slength', 'swidth', 'plength', 'pwidth', 'class']
#  array去前五行数据的方法 < - >pandas:dataFrame.head()
# head = iris.data[:20]
head = iris_data.head()
print('head: ', head)


img = Image.open('test.jpg')
plt.imshow(img)
plt.show()

# 统计信息
desc = iris_data.describe()
print('iris_data describe: ', desc)



# seaborn很好用的画图包
# 查看特征与特征之间的关系
# iris_data: dataFrame, 不能有缺失值: dropna()去掉缺失值
sb.pairplot(iris_data.dropna(), hue = 'class')


# violin图可用于分类任务
# 创建画白色画板
plt.figure()
for column_index, column in enumerate(iris_data.columns):
    if column == 'class':
        continue
    plt.subplot(2, 2, column_index + 1)
    #  x、y、 包含属性和类别数据集
    sb.violinplot(x = 'class', y = column, data = iris_data)
    

    
# 构建决策树
# from sklearn import datasets
# X, y = datasets.load_iris()['data'], datasets.load_iris()['target']
# print("X :",X)
# print("y :", y)
# # 标准化数据
# all_inputs = preprocessing.scale(X)

all_inputs = iris_data[['slength', 'swidth', 'plength', 'pwidth']].values
all_classes = iris_data['class'].values



# 交叉验证    
(training_inputs, testing_inputs, training_classes, testing_classes) = train_test_split(
    all_inputs, all_classes, train_size = 0.75, random_state = 1)

clf = DecisionTreeClassifier(random_state = 0)

# train
clf.fit(training_inputs, training_classes)

# test
score = clf.score(testing_inputs, testing_classes)
print('score: ', score)

# predict
print(testing_inputs[0])
print(clf.predict([testing_inputs[0]]))
# print(clf.predict(testing_inputs))
print('classes: ', clf.classes_)

查看前五行数据集:

image.png

显示鸢尾花图片:


image.png

鸢尾花数据集统计信息


image.png

使用seaborn查看特征之间的关系:


image.png

或可用于分类的seaborn图:


image.png

使用决策树的预测结果如下:


image.png

你可能感兴趣的:(17. 决策树参数实例)