[1]:
from sklearn.datasets import load_iris
[2]:
iris=load_iris()
[3]:
dir(iris)
[3]:
['DESCR', 'data', 'data_module', 'feature_names', 'filename', 'frame', 'target', 'target_names']
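The Bunch object also exposes the column names and class labels; a quick optional check, not part of the original cells:

print(iris.feature_names)   # ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
print(iris.target_names)    # ['setosa' 'versicolor' 'virginica']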
[4]:
x=iris.data
y=iris.target
[5]:
print(x.shape)
(150, 4)
[6]:
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
[7]:
X_train,X_test,y_train,y_test=train_test_split(x,y,test_size=0.3,random_state=888)
[8]:
print(X_train.shape,X_test.shape)
(105, 4) (45, 4)
[9]:
print(y_train.shape)
(105,)
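Because the split is random, the three classes need not be equally represented in the 105 training samples (the root node later shows value = [34, 33, 38]). If you want the 50/50/50 class balance preserved, train_test_split accepts a stratify argument; a minimal variant of the split above, not used in the rest of this run:

# Stratified variant: keeps class proportions the same in both subsets
X_train,X_test,y_train,y_test=train_test_split(x,y,test_size=0.3,random_state=888,stratify=y)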
[12]:
clf=DecisionTreeClassifier(criterion='entropy',random_state=888)
clf=clf.fit(X_train,y_train)
[14]:
clf.score(X_test,y_test)
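score returns plain accuracy on the 45 held-out samples. For per-class precision and recall you could additionally run the following sketch using sklearn.metrics (not part of the original notebook):

from sklearn.metrics import classification_report
y_pred=clf.predict(X_test)                                          # predicted labels on the test set
print(classification_report(y_test,y_pred,target_names=iris.target_names))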
[15]:
from sklearn import tree
[16]:
print(tree.export_text(clf))
|--- feature_3 <= 0.75
|   |--- class: 0
|--- feature_3 > 0.75
|   |--- feature_2 <= 4.75
|   |   |--- class: 1
|   |--- feature_2 > 4.75
|   |   |--- feature_2 <= 4.95
|   |   |   |--- feature_3 <= 1.65
|   |   |   |   |--- class: 1
|   |   |   |--- feature_3 > 1.65
|   |   |   |   |--- feature_1 <= 3.10
|   |   |   |   |   |--- class: 2
|   |   |   |   |--- feature_1 > 3.10
|   |   |   |   |   |--- class: 1
|   |   |--- feature_2 > 4.95
|   |   |   |--- class: 2
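The text dump labels splits by column index (feature_0 … feature_3, in the order of iris.feature_names, so feature_3 is petal width and feature_2 is petal length). export_text also accepts feature_names directly if you want readable split rules:

print(tree.export_text(clf,feature_names=iris.feature_names))   # same tree, with real column names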
[17]:
import matplotlib.pyplot as plt
[64]:
fig,ax=plt.subplots(figsize=(10,10))
tree.plot_tree(clf,feature_names=iris.feature_names,class_names=iris.target_names,filled=True)
[64]:
[Output: the list of matplotlib Text annotations returned by plot_tree; the rendered tree matches the text dump above, with the root splitting on petal width (cm) <= 0.75 over 105 samples (value = [34, 33, 38]) and deeper splits on petal length (cm), petal width (cm) and sepal width (cm).]
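plot_tree returns that list of annotation objects as its value, which the notebook echoes; the drawing itself lands on the axes created by plt.subplots. To keep the rendered tree, you could save the figure, a small optional addition with a hypothetical filename:

fig.savefig('iris_tree.png',dpi=150,bbox_inches='tight')   # write the plotted tree to an image file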
Complete code
from sklearn.datasets import load_iris
iris=load_iris()
dir(iris)
x=iris.data
y=iris.target
print(x.shape)
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
X_train,X_test,y_train,y_test=train_test_split(x,y,test_size=0.3,random_state=888)
print(X_train.shape,X_test.shape)
print(y_train.shape)
clf=DecisionTreeClassifier(criterion='entropy',random_state=888)
clf=clf.fit(X_train,y_train)
clf.score(X_test,y_test)
from sklearn import tree
print(tree.export_text(clf))
import matplotlib.pyplot as plt
fig,ax=plt.subplots(figsize=(10,10))
tree.plot_tree(clf,feature_names=iris.feature_names,class_names=iris.target_names,filled=True)
plt.show()
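Once fitted, the same clf can classify new measurements; a minimal usage sketch with a made-up sample (sepal length, sepal width, petal length, petal width in cm):

sample=[[5.1,3.5,1.4,0.2]]                       # hypothetical flower measurements
print(iris.target_names[clf.predict(sample)])    # expected to print ['setosa']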