鸢尾花数据集,特征为连续值数据的决策树的多分类

1.导入工具

import pandas as pd
from sklearn import preprocessing
from sklearn import tree
from sklearn.datasets import load_iris

2.导入鸢尾花数据集,探索数据集
iris=load_iris()
#iris是一个字典,包含了数据、标签、标签名、数据描述等信息。可以通过键来索引对应值。
iris
#查看iris字典里的所有键
dir(iris)
iris.data
#150个数据,每个数据都有四个维度的特征,每个特征都是连续数值
iris.data.shape
#四个特征列名
iris.feature_names
#标签,0,1,2对应三种不同的鸢尾花
iris.target
#三种鸢尾花的名字
iris.target_names
鸢尾花数据集的描述说明信息
print(iris.DESCR)

3.构建决策树模型
dir(iris)
clf=tree.DecisionTreeClassifier(max_depth=4)
clf=clf.fit(iris.data, iris.target)
clf

4.可视化决策树
import pydotplus
from IPython.display import Image,display
dot_data=tree.export_graphviz(clf,
                             out_file=None,
                             feature_names=iris.feature_names,
                             class_names=iris.target_names,
                             filled=True,
                             rounded=True
                             )
graph=pydotplus.graph_from_dot_data(dot_data)
display(Image(graph.create_png()))


5.对整个训练集做预测
clf.predict(iris.data)

6.对单个样本做预测
#假设有一朵新的鸢尾花,四个特征分别为6.6cm,2.5cm,4.3cm,1,3cm。用训练好的决策树判断它属于哪一类鸢尾花。
import numpy as np
a1=np.array([6.6, 2.5, 4.3, 1.3])
a1
a1.shape
a1.reshape(1,-1).shape
clf.predict(a1.reshape(1,-1))
#属于第二类鸢尾花。
7.对多个样本做预测
a1=iris.data[30]
a2=iris.data[70]
a3=iris.data[120]
import numpy as np
b=np.row_stack((a1,a2,a3))
b
clf.predict(b)
import numpy as np
import matplotlib.pyplot as plt
%matplotlib.colors import ListedIormap
from matplotlib.colors import ListedColormap 
from sklearn import datasets
from sklearn import tree
iris=datasets.load_iris()
x=iris.data[:,2:4]#取出花瓣的长和宽
y=iris.target#取出标签
#计算散点图的上下界
x_min,x_max=x[:,0].min() -.5,  x[:,0].max()+.5
y_min,y_max=x[:,1].min() -.5,  x[:,1].max()+.5
#绘制边界
camo=cmap_light=ListedColormap(['#AAAAFF','#AAFFAA','#FFAAAA'])
h=.02
xx,yy=np.meshgrid(np.arange(x_min,x_max,h),np.arange(y_min,y_max,h))
clf=tree.DecisionTreeClassifier(max_depth=4)
clf=clf.fit(x, y)
Z=clf.predict(np.c_[xx.ravel(),yy.ravel()])
Z=Z.reshape(xx.shape)
plt.figure()
plt.pcolormesh(xx,yy,Z,cmap=cmap_light)
plt.scatter(x[:,0],x[:,1],c=y) 
plt.xlim(xx.min(),xx.max())
plt.ylim(yy.min(),yy.max())
plt.show()
 

你可能感兴趣的:(数据挖掘,决策树)