Python 09 决策树分类边界

Q:采用 Iris 数据集的前两个属性和前100个数据集,构建决策树,并画出类似于书籍图4.11的分类边界
  • 构建决策树
from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import DecisionTreeRegressor
from sklearn.tree import plot_tree


plt.rcParams['font.sans-serif']=['SimHei'] #显示中文标签
plt.rcParams['axes.unicode_minus']=False

iris = load_iris()
X = iris.data[:,:2]
Y = iris.target
# X_0,Y = load_iris(return_X_y=True)

clf = DecisionTreeClassifier(max_depth=6)
# clf = DecisionTreeRegressor()
clf = clf.fit(X,Y)
plt.figure(figsize=(12,12))  # set plot size (denoted in inches)
plot_tree(clf,filled='True',fontsize=6)
plt.show()
Python 09 决策树分类边界_第1张图片 标题

  • 定义函数:画出数据点和边界
def border_of_classifier(sklearn_cl,x,y):

    # param sklearn_cl:skearn的分类器
    # param x: np.array
    # param y: np.array

    ## 1生成网格数据
    x_min,y_min = x.min(axis=0)- 1
    x_max,y_max = x.max(axis= 0) +1
    # axis=0
    # 沿着行(rows)的方向跨列
    # axis=1
    # 沿着列(cols)的方向跨行

    #利用一组网格数据求出方程的值,然后把边界画出来。
    x_values, y_values = np.meshgrid(np.arange(x_min, x_max, 0.01 ),
    np.arange(y_min, y_max, 0.01))

    #计算出分类器对所有数据点的分类结果生成网格采样
    mesh_output = sklearn_cl.predict(np.c_[x_values.ravel(), y_values.ravel()])
    # np.c_[array1, array2]  # 把数组array1和数组array2配对后输出
    # ravel()方法将数组维度拉成一维数组

    #数组维度变形
    mesh_output = mesh_output.reshape(x_values.shape)

    #plt.pcolormesh()会根据mesh_output的结果自动的在cmap中选择颜色
    plt.contourf(x_values, y_values, mesh_output ,cmap=plt.cm.Spectral)  # 等高线
    # plt.pcolormesh(x_values, y_values, mesh_output ,cmap = 'rainbow')
    plt.scatter(X[Y==0, 0], X[Y == 0, 1])
    plt.scatter(X[Y==1, 0], X[Y == 1, 1])
    plt.scatter(X[Y==2, 0], X[Y == 2, 1])
    plt.xlim(x_values. min(), x_values.max())
    plt.ylim(y_values.min(), y_values.max())
    # plt.xlabel('SepaLength(单位cm)')
    # plt.ylabel('SepalWidth(单位cm)')
    # plt.title('Iris 数据集(前两个属性)')
    
    #设置x轴和y轴
    plt.xticks((np. arange(np.ceil(min(x[:, 0])- 1), np.ceil(max(x[:, 0])+ 1), 1.0)))
    plt.yticks((np. arange(np.ceil(min(x[:,1])- 1), np.ceil(max(x[:,1])+ 1),1.0)))
    plt.show()
  • 调用函数,绘图
X = np.array(X)
Y = np.array(Y)
border_of_classifier(clf, X, Y)

# plt.scatter(X[Y == 0, 0], X[Y == 0, 1], c='red', marker='o', label='山鸢尾')
# plt.scatter(X[Y == 1, 0], X[Y == 1, 1], c='green', marker='+', label='变色鸢尾花')
# plt.scatter(X[Y == 2, 0], X[Y == 2, 1], c='blue', marker='*', label='维吉尼亚鸢尾')

 Python 09 决策树分类边界_第2张图片

 参考:

Sklearn plot_tree图太小 - 问答 - 云+社区 - 腾讯云 (tencent.com)

机器学习技巧_绘制简单分类器边界(决策树SVM) - 百度文库 (baidu.com)

你可能感兴趣的:(Python和机器学习,决策树,python,分类)