Scikit-Learn Study Notes: Exploring Handwritten Digits (MNIST)


# Load and visualize the handwritten digits
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
digits = load_digits()
digits.images.shape
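
The shape check above is easier to read next to the feature matrix. `load_digits` returns scikit-learn's bundled digits dataset (8x8 images, smaller than the 28x28 images of the original MNIST). A minimal sketch, assuming only that bundled dataset, of how `digits.images` relates to the `digits.data` array used for dimensionality reduction and classification later on:

```python
import numpy as np
from sklearn.datasets import load_digits

digits = load_digits()

# images holds each sample as an 8x8 grid; data holds the same pixels as 64 flat features
print(digits.images.shape)  # (1797, 8, 8)
print(digits.data.shape)    # (1797, 64)

# Flattening every image row by row reproduces the feature matrix
flattened = digits.images.reshape(len(digits.images), -1)
same_pixels = np.array_equal(flattened, digits.data)
print(same_pixels)
```

This is why the test features can later be turned back into images with `reshape(-1, 8, 8)`.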


# Plot the first 100 digits in a 10x10 grid, hiding the axis ticks
fig, axes = plt.subplots(10, 10, figsize=(8, 8),
                         subplot_kw={'xticks': [], 'yticks': []},
                         gridspec_kw=dict(hspace=0.1, wspace=0.1))
for i, ax in enumerate(axes.flat):
    ax.imshow(digits.images[i], cmap='binary', interpolation='nearest')
    # Write the true label in green in the lower-left corner
    ax.text(0.05, 0.05, str(digits.target[i]), transform=ax.transAxes, color='green')

[Figure 1: the first 100 digits, each with its label in green]

# Dimensionality reduction
from sklearn.manifold import Isomap
iso = Isomap(n_components=2)
iso.fit(digits.data)
data_projected = iso.transform(digits.data)

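# Visualize the 2D projection, colored by digit label
# (plt.get_cmap replaces plt.cm.get_cmap, which newer matplotlib releases removed)
plt.scatter(data_projected[:, 0], data_projected[:, 1], c=digits.target,
            edgecolor='none', alpha=0.5, cmap=plt.get_cmap('Spectral', 10))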
plt.colorbar(label='digit label', ticks=range(10))
plt.clim(-0.5, 9.5)

[Figure 2: Isomap projection of the digits, colored by digit label]

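# Classification
# (train_test_split lives in sklearn.model_selection; sklearn.cross_validation was removed)
from sklearn.model_selection import train_test_split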
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score
x_train, x_test, y_train, y_test = train_test_split(digits.data, digits.target, 
                                                    random_state=0)
model = GaussianNB()
model.fit(x_train, y_train)
y_model = model.predict(x_test)
accuracy_score(y_model, y_test)

# Output
0.8333333333333334
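
The accuracy above is the fraction of test samples whose predicted label matches the true label. A minimal sketch, assuming the same split and model as above, that computes this fraction by hand and compares it with `accuracy_score` (`manual_accuracy` and `library_accuracy` are illustrative names, not part of the original notes):

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

digits = load_digits()
x_train, x_test, y_train, y_test = train_test_split(digits.data, digits.target,
                                                    random_state=0)
y_model = GaussianNB().fit(x_train, y_train).predict(x_test)

# Accuracy is the share of test samples whose predicted label equals the true label
manual_accuracy = np.mean(y_model == y_test)
library_accuracy = accuracy_score(y_test, y_model)
print(len(y_test), manual_accuracy, library_accuracy)
```

With the default 25% test split, the test set holds 450 of the 1797 samples, so the reported 0.8333... corresponds to 375 correct predictions.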
# Inspect the model's confusion matrix to see where it performs poorly
import seaborn as sns
from sklearn.metrics import confusion_matrix
mat = confusion_matrix(y_test, y_model)

sns.heatmap(mat, square=True, annot=True, cbar=False)
plt.xlabel('predicted value')
plt.ylabel('true value')

[Figure 3: confusion matrix heatmap, predicted value vs. true value]
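
The heatmap can also be read programmatically. A minimal sketch, assuming the same split and model as above, that zeroes the diagonal of the confusion matrix and lists the three most frequent misclassifications (`errors` and `confused_pairs` are illustrative names; row and column indices equal digit labels here only because all ten digits appear in the test set):

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

digits = load_digits()
x_train, x_test, y_train, y_test = train_test_split(digits.data, digits.target,
                                                    random_state=0)
y_model = GaussianNB().fit(x_train, y_train).predict(x_test)

# Rows are true labels, columns are predicted labels
mat = confusion_matrix(y_test, y_model)

# Zero the diagonal so only misclassifications remain
errors = mat.copy()
np.fill_diagonal(errors, 0)

# Flat indices of the three largest off-diagonal counts, largest first
top = np.argsort(errors, axis=None)[::-1][:3]
rows, cols = np.unravel_index(top, errors.shape)
confused_pairs = [(int(r), int(c)) for r, c in zip(rows, cols)]

for true_label, predicted_label in confused_pairs:
    count = errors[true_label, predicted_label]
    print(f"true {true_label} predicted as {predicted_label}: {count} times")
```

The pairs with the highest counts are the digits this Gaussian naive Bayes model confuses most often.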

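# Another intuitive way to see how the model behaves is to plot the test samples with the
# predicted label in the lower-left corner: green for a correct prediction, red for a wrong one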
fig, axes = plt.subplots(10, 10, figsize=(8,8), subplot_kw={'xticks':[],'yticks':[]},
                         gridspec_kw=dict(hspace=0.1, wspace=0.1))
test_images = x_test.reshape(-1, 8, 8)
for i, ax in enumerate(axes.flat):
    ax.imshow(test_images[i], cmap='binary', interpolation='nearest')
    ax.text(0.05, 0.05, str(y_model[i]),transform=ax.transAxes,
            color='green' if (y_test[i] == y_model[i]) else 'red')

[Figure 4: test digits with predicted labels, green for correct and red for incorrect]
