1.机器学习之(4) Handwritten Digits Data Set初体验

载入数据

import numpy as np
import matplotlib 
import matplotlib.pyplot as plt
from sklearn import datasets
digits = datasets.load_digits()

得到数据集中的数据

X = digits.data
y = digits.target

可视化一下,二进制图像显示

随便选一个样本

some_digit = X[111]
some_digit_image = some_digit.reshape(8,8)
plt.imshow(some_digit_image, cmap = matplotlib.cm.binary)
plt.show()

1.机器学习之(4) Handwritten Digits Data Set初体验_第1张图片

y[111]

输出为4。

调用sklearn库

数据预处理

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

创建kNN模型并训练

from sklearn.neighbors import KNeighborsClassifier
kNN_classifier = KNeighborsClassifier(n_neighbors=3)
kNN_classifier.fit(X_train, y_train)

测试模型准确率

kNN_classifier.score(X_test,y_test)

输出

0.9861111111111112

你可能感兴趣的:(机器学习)