基于svm机器学习的手写数字识别

机器学习入门来说,手写数字识别是个很不错的练习项目

而我们这里基于svm练习我们的所学习的机器学习。而我们选择的训练集是MNIST,这个训练集量大,好用,有几万张纯手写28*28的数字图像,适合我们这些初学者进行练手使用。

相关的完整项目已经上传到 github 上,内有代码和训练集

网址为 https://github.com/grey-wood-wolf/Handwritten-digit-recognition

那么下面就开始我们的练手

首先我们将所需要用到的,或者可能用到的库给import,大部分库调用命令行就可以下载

import os
import struct
from datetime import datetime
from matplotlib import pyplot as plt
import numpy as np
from sklearn import  svm
from PIL import Image

之后我们就要加载我们的训练集

这里我们编写加载的函数

def load_mnist(path, kind='train'):
    labels_path = os.path.join(path, '%s-labels-idx1-ubyte' % kind) 
    images_path = os.path.join(path, '%s-images-idx3-ubyte' % kind) 

    with open(labels_path, 'rb') as lbpath:
        magic, n = struct.unpack('>II', lbpath.read(8)) 
        labels = np.fromfile(lbpath, dtype=np.uint8)

    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16)) 
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784)
    return images, labels

这里将图片转换成numpy的形式保存,并且一个28*28的图片保存为1*784,便于之后的训练

训练集需要提前解压

之后我们编写我们的训练函数

def train(train_num):#train_num为训练的量,最多6个W
    X_train, y_train = load_mnist('.//dataset//MNIST//raw//', kind='train')  # 加载训练集
    #X = preprocessing.StandardScaler().fit_transform(X_train)
    X=X_train
    X_train = X[0:train_num]  # 训练60000张
    y_train = y_train[0:train_num]

    dt = datetime.now()#计时用的
    print('time is ' + dt.strftime('%Y-%m-%d %H:%M:%S'))

    #调用函数进行训练,参数可以自己调,效果可能更好
    model_svc = svm.SVC(kernel='rbf', gamma='scale')
    model_svc.fit(X_train, y_train)

    dt = datetime.now()
    print('time is ' + dt.strftime('%Y-%m-%d %H:%M:%S'))

    return model_svc

这里函数中调用fit()函数,将我们加载的训练集代入进去,就可以实现训练了,核函数我们采用rbf,具体的实现主要是寻找超平面,感兴趣的可以去自行学习。

之后会返回一个训练好的模型

之后就是测试函数

def test(model_svc, test_num):
    test_images, test_labels = load_mnist('.//dataset//MNIST//raw//', kind='t10k')  # 加载测试集
    #x = preprocessing.StandardScaler().fit_transform(test_images)
    x=test_images #测试图片
    x_test = x[0:test_num]
    y_test = test_labels[0:test_num]

    print(model_svc.score(x_test, y_test))  
    #return model_svc.score(x_test, y_test)
    return test_images, test_labels, x

其中score()函数就可以返回测试训练集带入模型中后,正确率是有多少

这里多返回了一个x,我以为没有返回test_images,可以忽略.......

最后编写我们的预测函数

def pred(model_svc, pred_num, test_images, test_labels, x):
    y_pred = model_svc.predict(x[9690 - pred_num:9690])  # 进行预测,能得到一个结果
    print(y_pred)

    X_show = test_images[9690 - pred_num:9690]
    #Y_show = test_labels[9690 - pred_num:9690]

    #打印图片看看效果
    for i in range(pred_num):
        x_show = X_show[i].reshape(28, 28)
        plt.subplot(1, pred_num, i + 1)
        plt.imshow(x_show, cmap=plt.cm.gray_r)
        plt.title(str(y_pred[i]))
        plt.axis('off')
    plt.show()

这里调用了predict()函数,可以对我们的的一些图片进行预测,并且返回预测结果,之后我们打印出来比对看效果。

最后我们调用这几个函数就可以实现我们的完整的手写数字识别的机器学习了

model = train(2000)#训练个数
test_images, test_labels, x = test(model,9900)
pred(model,9,test_images, test_labels, x)

效果如下:

基于svm机器学习的手写数字识别_第1张图片

 基于svm机器学习的手写数字识别_第2张图片

 我这里只训练了2000的量,所以效果不好,实际尝试中可以试试更多的,但时间也会成几何倍速增长。

这里我们自己也手画了一个手写数字给他识别:

#下列测试自己的图片(自己的数字效果不好,但用MNIST效果不错)
image_file = Image.open(".//mynum.png") # open colour image
image_file = image_file.resize((28,28))
image_file = image_file.convert('L') # convert image to black and white
image_file = np.array(image_file,dtype=np.uint8)
image_file = image_file.reshape(1,784)
mypred = model.predict(image_file)
print(mypred)
plt.imshow(Image.open(".//mynum.png"), cmap=plt.cm.gray_r)
plt.title(str(mypred[0]))
plt.show()

基于svm机器学习的手写数字识别_第3张图片

 但效果不好,这个对的也是我碰巧搞出来的,不过调调参数,或者,用自己手写个几千张的图片去训练,应该会好很多。(毕竟MNIST训练集的数字长得都很“漂亮”,没有什么奇奇怪怪的)

之后我也是闲的没事,想看看多大的训练量能达到一个不错的水平,所以直接分步长看训练效果了

scores = []
tra=[]
for i in range(2000,60000)[::2000]:
    model = train(i)
    scores += [test(model,1000)]
    tra +=[i]
plt.scatter(tra, scores)
plt.show()

效果是这个样子:

横坐标为训练量,纵坐标为正确率

基于svm机器学习的手写数字识别_第4张图片

 所以这里训练量在25000就可以了,之后的虽然还有提升,时间就太长了。

谢谢大家的观看,慢慢学习,一同奋斗。

你可能感兴趣的:(支持向量机,机器学习,人工智能)