Q:采用 Iris 数据集的前两个属性和前100个数据集,构建决策树,并画出类似于书籍图4.11的分类边界
from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import DecisionTreeRegressor
from sklearn.tree import plot_tree
plt.rcParams['font.sans-serif']=['SimHei'] #显示中文标签
plt.rcParams['axes.unicode_minus']=False
iris = load_iris()
X = iris.data[:,:2]
Y = iris.target
# X_0,Y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(max_depth=6)
# clf = DecisionTreeRegressor()
clf = clf.fit(X,Y)
plt.figure(figsize=(12,12)) # set plot size (denoted in inches)
plot_tree(clf,filled='True',fontsize=6)
plt.show()
def border_of_classifier(sklearn_cl,x,y):
# param sklearn_cl:skearn的分类器
# param x: np.array
# param y: np.array
## 1生成网格数据
x_min,y_min = x.min(axis=0)- 1
x_max,y_max = x.max(axis= 0) +1
# axis=0
# 沿着行(rows)的方向跨列
# axis=1
# 沿着列(cols)的方向跨行
#利用一组网格数据求出方程的值,然后把边界画出来。
x_values, y_values = np.meshgrid(np.arange(x_min, x_max, 0.01 ),
np.arange(y_min, y_max, 0.01))
#计算出分类器对所有数据点的分类结果生成网格采样
mesh_output = sklearn_cl.predict(np.c_[x_values.ravel(), y_values.ravel()])
# np.c_[array1, array2] # 把数组array1和数组array2配对后输出
# ravel()方法将数组维度拉成一维数组
#数组维度变形
mesh_output = mesh_output.reshape(x_values.shape)
#plt.pcolormesh()会根据mesh_output的结果自动的在cmap中选择颜色
plt.contourf(x_values, y_values, mesh_output ,cmap=plt.cm.Spectral) # 等高线
# plt.pcolormesh(x_values, y_values, mesh_output ,cmap = 'rainbow')
plt.scatter(X[Y==0, 0], X[Y == 0, 1])
plt.scatter(X[Y==1, 0], X[Y == 1, 1])
plt.scatter(X[Y==2, 0], X[Y == 2, 1])
plt.xlim(x_values. min(), x_values.max())
plt.ylim(y_values.min(), y_values.max())
# plt.xlabel('SepaLength(单位cm)')
# plt.ylabel('SepalWidth(单位cm)')
# plt.title('Iris 数据集(前两个属性)')
#设置x轴和y轴
plt.xticks((np. arange(np.ceil(min(x[:, 0])- 1), np.ceil(max(x[:, 0])+ 1), 1.0)))
plt.yticks((np. arange(np.ceil(min(x[:,1])- 1), np.ceil(max(x[:,1])+ 1),1.0)))
plt.show()
X = np.array(X)
Y = np.array(Y)
border_of_classifier(clf, X, Y)
# plt.scatter(X[Y == 0, 0], X[Y == 0, 1], c='red', marker='o', label='山鸢尾')
# plt.scatter(X[Y == 1, 0], X[Y == 1, 1], c='green', marker='+', label='变色鸢尾花')
# plt.scatter(X[Y == 2, 0], X[Y == 2, 1], c='blue', marker='*', label='维吉尼亚鸢尾')
参考:
Sklearn plot_tree图太小 - 问答 - 云+社区 - 腾讯云 (tencent.com)
机器学习技巧_绘制简单分类器边界(决策树SVM) - 百度文库 (baidu.com)