Iris Classification with Jupyter


1. Load the iris dataset

[1]:

from sklearn.datasets import load_iris

[2]:

iris=load_iris()

[3]:

dir(iris)

[3]:

['DESCR',
 'data',
 'data_module',
 'feature_names',
 'filename',
 'frame',
 'target',
 'target_names']

[4]:

x=iris.data
y=iris.target

[5]:

print(x.shape)
(150, 4)
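The shape alone does not say what the four columns or the integer labels mean. A minimal sketch, assuming a standard scikit-learn installation, that prints the feature names, the class names, and the number of samples per class (the `class_counts` variable is introduced here for illustration only):

```python
import numpy as np
from sklearn.datasets import load_iris

iris = load_iris()

# Column names of iris.data, in column order
print(iris.feature_names)

# Class names; the integer labels in iris.target index into this array
print(iris.target_names)

# Number of samples in each class
class_counts = np.bincount(iris.target)
print({str(name): int(n) for name, n in zip(iris.target_names, class_counts)})
```

The dataset is balanced, which is worth knowing before reading the accuracy score later on.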

[6]:

from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

2. Split the data into training and test sets at a 7:3 ratio, with the random seed set to 888

[7]:

X_train,X_test,y_train,y_test=train_test_split(x,y,test_size=0.3,random_state=888)

[8]:

print(X_train.shape,X_test.shape)
(105, 4) (45, 4)

[9]:

print(y_train.shape)
(105,)
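A plain random split does not guarantee that each class appears in the same proportion in both sets. The sketch below repeats the split above and prints the per-class counts, then shows a stratified variant for comparison. The `stratify=y` variant is not part of this notebook, and the `_s` variable names are hypothetical.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

x, y = load_iris(return_X_y=True)

# Same split as in the notebook: 30% test, seed 888
X_train, X_test, y_train, y_test = train_test_split(
    x, y, test_size=0.3, random_state=888
)
print(np.bincount(y_train, minlength=3), np.bincount(y_test, minlength=3))

# Variant (not used in the notebook): stratify=y keeps class proportions equal
X_tr_s, X_te_s, y_tr_s, y_te_s = train_test_split(
    x, y, test_size=0.3, random_state=888, stratify=y
)
print(np.bincount(y_tr_s, minlength=3), np.bincount(y_te_s, minlength=3))
```

With 50 samples per class, the stratified split yields 35 training and 15 test samples for every class.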

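3. Train an ID3-style decision tree on the training set, with the same seed as above (scikit-learn's DecisionTreeClassifier is a CART implementation; criterion='entropy' selects the information-gain splitting that ID3 uses)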

[12]:

clf=DecisionTreeClassifier(criterion='entropy',random_state=888)
clf=clf.fit(X_train,y_train)
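Once the tree is fitted, a few attributes summarize what it learned. A minimal sketch, assuming the same data, split, and seed as above, that reports the tree depth, the number of leaves, and the relative importance of each feature:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.3, random_state=888
)

clf = DecisionTreeClassifier(criterion='entropy', random_state=888)
clf.fit(X_train, y_train)

# Size of the fitted tree
print("depth:", clf.get_depth())
print("leaves:", clf.get_n_leaves())

# Normalized importances (they sum to 1); higher means the feature
# contributed more to reducing entropy across the splits
for name, importance in zip(iris.feature_names, clf.feature_importances_):
    print(f"{name}: {importance:.3f}")
```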

4. Evaluate the model's classification performance

[14]:

clf.score(X_test,y_test)
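`clf.score` returns the mean accuracy on the test set, which hides which classes get confused with each other. A sketch, assuming the same split and model as above, that adds a confusion matrix using `sklearn.metrics`:

```python
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.3, random_state=888
)
clf = DecisionTreeClassifier(criterion='entropy', random_state=888)
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)

# Same quantity that clf.score(X_test, y_test) returns
acc = accuracy_score(y_test, y_pred)
print("accuracy:", acc)

# Rows are true classes, columns are predicted classes
cm = confusion_matrix(y_test, y_pred, labels=[0, 1, 2])
print(cm)
```

Off-diagonal entries in the matrix are the misclassified test samples.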

5. Plot the decision tree

[15]:

from sklearn import tree

[16]:

print(tree.export_text(clf))
|--- feature_3 <= 0.75
|   |--- class: 0
|--- feature_3 >  0.75
|   |--- feature_2 <= 4.75
|   |   |--- class: 1
|   |--- feature_2 >  4.75
|   |   |--- feature_2 <= 4.95
|   |   |   |--- feature_3 <= 1.65
|   |   |   |   |--- class: 1
|   |   |   |--- feature_3 >  1.65
|   |   |   |   |--- feature_1 <= 3.10
|   |   |   |   |   |--- class: 2
|   |   |   |   |--- feature_1 >  3.10
|   |   |   |   |   |--- class: 1
|   |   |--- feature_2 >  4.95
|   |   |   |--- class: 2
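The generic names `feature_0` to `feature_3` refer to the columns of `iris.data` in order: sepal length, sepal width, petal length, petal width. As a small sketch, passing `feature_names` to `export_text` prints the readable names instead (same data, split, and model as above are assumed):

```python
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.3, random_state=888
)
clf = DecisionTreeClassifier(criterion='entropy', random_state=888)
clf.fit(X_train, y_train)

# feature_names replaces feature_0..feature_3 with the column names
text = tree.export_text(clf, feature_names=iris.feature_names)
print(text)
```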

[17]:

import matplotlib.pyplot as plt

[64]:

fig,ax=plt.subplots(figsize=(10,10))
tree.plot_tree(clf,feature_names=iris.feature_names,class_names=iris.target_names,filled=True)

[64]:

[Text(0.3333333333333333, 0.9166666666666666, 'petal width (cm) <= 0.75\nentropy = 1.582\nsamples = 105\nvalue = [34, 33, 38]\nclass = virginica'),
 Text(0.16666666666666666, 0.75, 'entropy = 0.0\nsamples = 34\nvalue = [34, 0, 0]\nclass = setosa'),
 Text(0.5, 0.75, 'petal length (cm) <= 4.75\nentropy = 0.996\nsamples = 71\nvalue = [0, 33, 38]\nclass = virginica'),
 Text(0.3333333333333333, 0.5833333333333334, 'entropy = 0.0\nsamples = 30\nvalue = [0, 30, 0]\nclass = versicolor'),
 Text(0.6666666666666666, 0.5833333333333334, 'petal length (cm) <= 4.95\nentropy = 0.378\nsamples = 41\nvalue = [0, 3, 38]\nclass = virginica'),
 Text(0.5, 0.4166666666666667, 'petal width (cm) <= 1.65\nentropy = 0.954\nsamples = 8\nvalue = [0, 3, 5]\nclass = virginica'),
 Text(0.3333333333333333, 0.25, 'entropy = 0.0\nsamples = 2\nvalue = [0, 2, 0]\nclass = versicolor'),
 Text(0.6666666666666666, 0.25, 'sepal width (cm) <= 3.1\nentropy = 0.65\nsamples = 6\nvalue = [0, 1, 5]\nclass = virginica'),
 Text(0.5, 0.08333333333333333, 'entropy = 0.0\nsamples = 5\nvalue = [0, 0, 5]\nclass = virginica'),
 Text(0.8333333333333334, 0.08333333333333333, 'entropy = 0.0\nsamples = 1\nvalue = [0, 1, 0]\nclass = versicolor'),
 Text(0.8333333333333334, 0.4166666666666667, 'entropy = 0.0\nsamples = 33\nvalue = [0, 0, 33]\nclass = virginica')]

[Figure: the decision tree rendered by tree.plot_tree]
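The `entropy` values in the node labels can be checked by hand from the `value` counts. The sketch below recomputes the entropy of the root node and of its right child from the counts shown in the output above, then derives the information gain of the first split. The `entropy` helper is written here for illustration and is not a scikit-learn function.

```python
import math

def entropy(counts):
    """Shannon entropy in bits of a list of class counts."""
    total = sum(counts)
    # Empty classes contribute nothing (the limit of p*log2(p) at 0 is 0)
    return -sum((c / total) * math.log2(c / total) for c in counts if c > 0)

# Class counts taken from the plot_tree output above
root = [34, 33, 38]    # all 105 training samples
left = [34, 0, 0]      # petal width <= 0.75
right = [0, 33, 38]    # petal width > 0.75

h_root = entropy(root)
h_left = entropy(left)
h_right = entropy(right)

# Information gain = parent entropy minus the weighted child entropies
n = sum(root)
gain = h_root - (sum(left) / n) * h_left - (sum(right) / n) * h_right

print(round(h_root, 3))   # 1.582, matching the root node label
print(round(h_right, 3))  # 0.996, matching the right child label
print(gain)
```

The left child is pure (entropy 0), so the first split already separates setosa completely.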

Complete code

import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# 1. Load the iris dataset
iris = load_iris()
print(dir(iris))
x = iris.data
y = iris.target
print(x.shape)

# 2. Split into training and test sets at 7:3, seed 888
X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=888)
print(X_train.shape, X_test.shape)
print(y_train.shape)

# 3. Train a decision tree with the entropy criterion, same seed
clf = DecisionTreeClassifier(criterion='entropy', random_state=888)
clf = clf.fit(X_train, y_train)

# 4. Mean accuracy on the test set (printed explicitly so it also shows outside a notebook)
print(clf.score(X_test, y_test))

# 5. Print the tree as text, then plot it
print(tree.export_text(clf))
fig, ax = plt.subplots(figsize=(10, 10))
tree.plot_tree(clf, feature_names=iris.feature_names, class_names=iris.target_names, filled=True)
plt.show()
