Scikit-Learn学习笔记——用随机森林识别手写数字

用随机森林识别手写数字

from sklearn.datasets import load_digits
digits = load_digits()

#显示前几个数字图像
fig = plt.figure(figsize=(6,6))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)
for i in range(64):
    ax = fig.add_subplot(8,8,i+1, xticks=[], yticks=[])
    ax.imshow(digits.images[i], cmap=plt.cm.binary, interpolation='nearest')
    ax.text(0,7,str(digits.target[i]))

Scikit-Learn学习笔记——用随机森林识别手写数字_第1张图片

#用随机森林快速对数字进行分类
from sklearn.cross_validation import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
x_train, x_test, y_train, y_test = train_test_split(digits.data, digits.target, random_state=0)
model = RandomForestClassifier(n_estimators=1000)
model.fit(x_train, y_train)
ypre = model.predict(x_test)

#查看分类器的分类结果报告
from sklearn import metrics
print(metrics.classification_report(ypre, y_test))

#输出结果:
             precision    recall  f1-score   support

          0       1.00      0.97      0.99        38
          1       0.98      0.98      0.98        43
          2       0.95      1.00      0.98        42
          3       0.98      0.96      0.97        46
          4       0.97      1.00      0.99        37
          5       0.98      0.96      0.97        49
          6       1.00      1.00      1.00        52
          7       1.00      0.96      0.98        50
          8       0.94      0.98      0.96        46
          9       0.98      0.98      0.98        47

avg / total       0.98      0.98      0.98       450
#画出混淆矩阵
from sklearn.metrics import confusion_matrix
mat = confusion_matrix(y_test, ypre)
sns.heatmap(mat.T, square=True, annot=True, fmt='d', cbar=False)
plt.xlabel('true label')
plt.ylabel('predicted label')

Scikit-Learn学习笔记——用随机森林识别手写数字_第2张图片

我们发现,用一个简单、未调优的随机森林对手写数字进行分类,就可以取得非常好的分类准确率。

你可能感兴趣的:(python,学习笔记,机器学习)