实现一个程序,需要先采集数据并且尽可能多的采集不同的数据(防止偶然性,使得数据具有代表性),然后对数据进行标记。
选择特征的直观方法:直接使用图片的每个像素点作为一个特征。 数据保存为样本个数×特征个数格式的array对象。
在采集数据完后,为了减少计算量,也为了模型的稳定性,我们需要对数据进行数据清洗,即把采集到的、不适合用来做机器学习训练的数据进行预处理,从而转化为适合机器学习的数据。
对于不同的数据集,选择不同的模型有不同的效率。因此在选择模型要考虑很多的因素,从众多的因素中找到一个最适合模型,同时这个模型要使结果模拟评分达到最高。
在进行模型训练之前,要将数据集划分为训练数据集和测试数据集,再利用划分好的数据集进行模型训练,最后得到我们训练出来的模型参数
模型测试的直观方法:用训练出来的模型预测测试数据集,然后将预测出来的结果与真正的结果进行比较,最后比较出来的结果即为模型的准确度。
当我们训练出一个满意的模型后可以将模型进行保存,这样当我们再一次需要使用此模型时可以直接利用此模型进行预测,不用再一次进行模型训练。
1.数据采集与标记
#导入库
import inline as inline
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
'''sk-learn库中自带了一些数据集,以上是手写数字识别图片的数据'''
#导入sklearn库中datasets模块
from sklearn import datasets
#利用datasets模块中的函数load_digits()进行加载
digits=datasets.load_digits()
#把数据所代表的图片显示出来
images_and_labless=list(zip(digits.images,digits.target))
plt.figure(figsize=(8,6))
for index,(image,lable) in enumerate(images_and_labels[:8]):
plt.subplot(2,4,index+1)
plt.axis('off')
plt.imshow(image,cmap=plt.cm.gray_r,interpolation='nearest')
plt.title('Digit:%i'%label,fontsize=20)
#将数据保存为[样本个数*特征个数]格式的array对象的格式进行输出
#数据保存在了digits.data文件中
print("shape of raw image data:{0}".format(digits.image.shape))
print("shape of data:{0}".format(digits.data.shape))
# 把数据分成训练数据集和测试数据集(此处将数据集的百分之二十作为测试数据集)
from sklearn.model_selection import train_test_split
Xtrain, Xtest, Ytrain, Ytest = train_test_split(digits.data, digits.target, test_size=0.20, random_state=2);
# 使用支持向量机来训练模型
from sklearn import svm
clf = svm.SVC(gamma=0.001, C=100., probability=True)
# 使用训练数据集Xtrain和Ytrain来训练模型
clf.fit(Xtrain, Ytrain);
'''
sklearn.metrics.accuracy_score(y_true, y_pred, normalize=True, sample_weight=None)
normalize:默认值为True,返回正确分类的比例;如果为False,返回正确分类的样本数
'''
# 评估模型的准确度(此处默认为true,直接返回正确的比例,也就是模型的准确度)
from sklearn.metrics import accuracy_score
# predict是训练后返回预测结果,是标签值。
Ypred = clf.predict(Xtest);
accuracy_score(Ytest, Ypred)
"""
将测试数据集里的部分图片显示出来
图片的左下角显示预测值,右下角显示真实值
"""
# 查看预测的情况
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
fig.subplots_adjust(hspace=0.1, wspace=0.1)
for i, ax in enumerate(axes.flat):
ax.imshow(Xtest[i].reshape(8, 8), cmap=plt.cm.gray_r, interpolation='nearest')
ax.text(0.05, 0.05, str(Ypred[i]), fontsize=32,
transform=ax.transAxes,
color='green' if Ypred[i] == Ytest[i] else 'red')
ax.text(0.8, 0.05, str(Ytest[i]), fontsize=32,
transform=ax.transAxes,
color='black')
ax.set_xticks([])
ax.set_yticks([])