scikit-learn/ID3手写数字识别

环境:python 3,scikit-learn 0.18

判定树是一个类似于流程图的树结构:其中,每个内部结点表示在一属性上的测试,
每个分支代表一个属性输出,而每个树叶结点代表类或类分布。树的顶层是根结点。
ID3算法根据的就是信息获取量(Information Gain):Gain(A) = Info(D) - Infor_A(D)

#coding:utf-8
"""
python 3 
sklearn 0.18
"""
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score,confusion_matrix,classification_report
import input_data
import numpy as np
import pickle

mnist = input_data.read_data_sets('mnist/',one_hot=False)
x = mnist.train.images
y = mnist.train.labels
#采用交叉验证
train_data,validation_data,train_labels,validation_labels = train_test_split(x,y,test_size=0.2)
#训练一个DecisionTree分类器
clf = DecisionTreeClassifier(random_state=0,splitter='best',criterion='entropy')
clf.fit(train_data,train_labels)
predictions=[]
for i in range(1000):
    if i % 100 ==0:
        print('= = = = = = > > > > > >','epoch:',int(i/100))
    #将预测结果存入predictions
    output = clf.predict([mnist.test.images[i]])
    predictions.append(output)
print(confusion_matrix(mnist.test.labels[0:1000],predictions))
print(classification_report(mnist.test.labels[0:1000],np.array(predictions)))
print('test accuracy is:',accuracy_score(mnist.test.labels[0:1000],predictions))
with open('id3.pickle','wb') as f:
    pickle.dump(clf,f)

结果

对于mnist手写数字的识别,采用ID3算法的检测精度达到87%,决策树的缺点在于分类时,类别越多错误增加越快,而且决策树越深越容易出现overfitting
scikit-learn/ID3手写数字识别_第1张图片

你可能感兴趣的:(深度学习机器学习,ID3,决策树,sklearn,mnist)