【ML-SVM案例学习】会有十种SVM案例,供大家用来学习。本章实现svm实现手写数字识别。
提示:以下是本篇文章正文内容,下面案例可供参考
代码如下(示例):
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn import datasets, svm, metrics
代码如下(示例):
mpl.rcParams['font.sans-serif'] = [u'SimHei']
mpl.rcParams['axes.unicode_minus'] = False
digits = datasets.load_digits()
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))
data.shape
classifier = svm.SVC(gamma=0.001) # 默认是rbf
# from sklearn.neighbors import KNeighborsClassifier
# classifier = KNeighborsClassifier(n_neighbors=9, weights='distance')
## 使用二分之一的数据进行模型训练
##取前一半数据训练,后一半数据测试
classifier.fit(data[:int(n_samples / 2)], digits.target[:int(n_samples / 2)])
## 后一半数据作为测试集
expected = digits.target[int(n_samples/2):] # y_test
predicted = classifier.predict(data[int(n_samples / 2):])##y_predicted
## 计算准确率
print("分类器%s的分类效果:\n%s\n"
% (classifier, metrics.classification_report(expected, predicted)))
## 生成一个分类报告classification_report
print("混淆矩阵为:\n%s" % metrics.confusion_matrix(expected, predicted))
## 生成混淆矩阵
print("score_svm:\n%f" % classifier.score(data[int(n_samples / 2):], digits.target[int(n_samples / 2):]))
plt.figure(facecolor='gray', figsize=(12,5))
images_and_predictions = list(zip(digits.images[int(n_samples / 2):][expected != predicted], expected[expected != predicted], predicted[expected != predicted]))
for index,(image,expection, prediction) in enumerate(images_and_predictions[:5]):
plt.subplot(2, 5, index + 1)
plt.axis('off')
plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest') # 把cmap中的灰度值与image矩阵对应,并填充
plt.title(u'预测值/实际值:%i/%i' % (prediction, expection))
images_and_predictions = list(zip(digits.images[int(n_samples / 2):][expected == predicted], expected[expected == predicted], predicted[expected == predicted]))
for index, (image,expection, prediction) in enumerate(images_and_predictions[:5]):
plt.subplot(2, 5, index + 6)
plt.axis('off')
plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
plt.title(u'预测值/实际值:%i/%i' % (prediction, expection))
plt.subplots_adjust(.04, .02, .97, .94, .09, .2)
plt.show()
以上就是今天要讲的内容,本文仅仅简单介绍了svm实现手写数字识别,仅供参考学习。