Handwritten digit recognition with SVM, compared with random forest

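This is based on Zou Bo's machine learning course materials, with a few small changes I made while following the lessons. Some points are still unclear to me, so I am posting it here for future discussion.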

'''
Handwritten digit recognition with SVM
'''
import pandas as pd
import numpy as np
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier
from PIL import Image
import matplotlib.pyplot as plt


if __name__ == '__main__':
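    # Each row of the optdigits files holds 64 pixel values (0-16) followed by the digit label
    data_train = pd.read_csv('14.optdigits.tra', header=None)
    data_test = pd.read_csv('14.optdigits.tes', header=None)
    x_train = data_train.iloc[:, :-1]   # features: every column except the last
    y_train = data_train.iloc[:, -1]    # label: the last column
    x_test = data_test.iloc[:, :-1]
    y_test = data_test.iloc[:, -1]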
    
    # Reshape each 64-value row back into an 8x8 image
    x_image = x_train.values.reshape(-1, 8, 8).astype(np.uint8)
    x_label = y_train.values
    
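    # optdigits pixel values only run from 0 to 16, so the raw image would be almost black.
    # Multiplying by 15 stretches them to 0-240, and 255 minus that inverts the result,
    # giving dark strokes on a white background.
    Image.fromarray(255 - x_image[5] * 15).save('./test.png')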
    # Plot the first 16 training images
    plt.figure(figsize=(12, 12))
    for i in range(16):
        plt.subplot(4, 4, i+1)
        plt.imshow(x_image[i], cmap=plt.cm.gray_r)
        plt.title('Training image: %d' % x_label[i], fontsize=12)
    plt.tight_layout(pad=1.2)   # recent Matplotlib requires pad as a keyword
    plt.show()

    # SVM prediction
    model = SVC(C=1, kernel='rbf', gamma=0.001, decision_function_shape='ovr')
    model.fit(x_train, y_train)
    y_hat = model.predict(x_test)
    print('SVM training-set accuracy:', accuracy_score(y_train, model.predict(x_train)))
    print('SVM test-set accuracy:', accuracy_score(y_test, y_hat))
    
    # Plot some of the misclassified test images
    x_test_image = x_test.values.reshape(-1, 8, 8).astype(np.uint8)
    error_ix = y_test[y_test != y_hat].index
    plt.figure(figsize=(16, 12))
    # Guard against there being fewer than 30 errors
    for i in range(min(30, len(error_ix))):
        plt.subplot(5, 6, i+1)
        image_ix = error_ix[i]
        plt.imshow(x_test_image[image_ix], cmap=plt.cm.gray_r)
        plt.title('True: %d, predicted: %d' % (y_test[image_ix], y_hat[image_ix]))
    plt.tight_layout(pad=1.2)
    plt.show()
    
    # Random forest prediction
    model = RandomForestClassifier(n_estimators=200, max_depth=5)
    model.fit(x_train, y_train)
    print('RF training-set accuracy:', accuracy_score(y_train, model.predict(x_train)))
    print('RF test-set accuracy:', accuracy_score(y_test, model.predict(x_test)))
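The script reads local copies of the UCI optdigits files (14.optdigits.tra / 14.optdigits.tes), which are not included here. For anyone without them, a minimal self-contained sketch of the same SVM vs. random forest comparison can use scikit-learn's bundled `load_digits`, which is a copy of the optdigits test set (1797 images of 8x8 pixels, values 0-16). Assumptions that are mine, not the course's: the 70/30 stratified split, `random_state=0`, and the hypothetical helper name `compare_svm_rf`. Because the data and split differ, the accuracies will not match the original run.

```python
# Sketch: SVM vs. random forest on scikit-learn's bundled digits data.
# Assumption: load_digits stands in for the 14.optdigits.* files; the split
# ratio and random_state are illustrative choices, not from the original.
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC


def compare_svm_rf(random_state=0):
    """Hypothetical helper: return the test accuracy of each model."""
    x, y = load_digits(return_X_y=True)  # x: (1797, 64), y: digit labels 0-9
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.3, random_state=random_state, stratify=y)

    # Same hyperparameters as the script above
    svm = SVC(C=1, kernel='rbf', gamma=0.001, decision_function_shape='ovr')
    svm.fit(x_train, y_train)

    rf = RandomForestClassifier(n_estimators=200, max_depth=5,
                                random_state=random_state)
    rf.fit(x_train, y_train)

    return {
        'svm': accuracy_score(y_test, svm.predict(x_test)),
        'rf': accuracy_score(y_test, rf.predict(x_test)),
    }


if __name__ == '__main__':
    for name, acc in compare_svm_rf().items():
        print('%s test-set accuracy: %.4f' % (name, acc))
```

Fixing `random_state` for both the split and the forest makes repeated runs comparable; without it the random forest score varies slightly from run to run.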

Most of the code was typed in while following the video; I still feel far from being able to implement it on my own...
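The script only prints one accuracy number per model. To see which digits get confused with which (the question the grid of misclassified images answers visually), a per-class report and a confusion matrix can be computed with `sklearn.metrics.classification_report` and `confusion_matrix`. This is a sketch under the same assumptions as before: `load_digits` replaces the original data files, the 70/30 split and `random_state=0` are my choices, and `svm_error_breakdown` is a hypothetical helper name.

```python
# Sketch: per-class breakdown of SVM errors on scikit-learn's digits data.
# Assumption: load_digits and the split stand in for the original files.
import numpy as np
from sklearn.datasets import load_digits
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC


def svm_error_breakdown(random_state=0, top_k=5):
    """Hypothetical helper: return (confusion matrix, text report, top confusions)."""
    x, y = load_digits(return_X_y=True)
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.3, random_state=random_state, stratify=y)
    model = SVC(C=1, kernel='rbf', gamma=0.001)
    model.fit(x_train, y_train)
    y_hat = model.predict(x_test)

    # Rows are true digits, columns are predicted digits
    cm = confusion_matrix(y_test, y_hat, labels=list(range(10)))
    report = classification_report(y_test, y_hat, digits=3)

    # Zero the diagonal (correct predictions) so only errors remain
    errors = cm.copy()
    np.fill_diagonal(errors, 0)
    flat = np.argsort(errors, axis=None)[::-1][:top_k]
    top = [(int(i // 10), int(i % 10), int(errors.flat[i]))
           for i in flat if errors.flat[i] > 0]
    return cm, report, top


if __name__ == '__main__':
    cm, report, top = svm_error_breakdown()
    print(report)
    for true_digit, pred_digit, count in top:
        print('true %d predicted as %d: %d times' % (true_digit, pred_digit, count))
```

The diagonal of the confusion matrix holds the correct predictions, so its total over the whole matrix equals the number of test samples; the largest off-diagonal cells point to the digit pairs worth inspecting in the misclassified-image plot.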
