Classifying the MNIST dataset with an SVM in Python

MNIST is a classic handwritten-digit dataset in machine learning. It contains 60,000 training images and 10,000 test images, each 28×28 pixels.

After downloading the archives, decompress them and place the resulting files in the same folder as the source code.
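The decompression step can be sketched as below, assuming the archives were downloaded as gzip files (`.gz`). The helper name `gunzip_all` is hypothetical and not part of the original post. The decompressed file names must match the ones the source code opens (for example `train-images.idx3-ubyte`), so rename the files if your archives use a different naming.

```python
import gzip
import shutil
from pathlib import Path


def gunzip_all(folder="."):
    """Decompress every *.gz file in `folder`, writing the result next to it.

    Hypothetical helper for illustration; returns the list of output paths.
    """
    out_paths = []
    for gz_path in sorted(Path(folder).glob("*.gz")):
        out_path = gz_path.with_suffix("")  # drop the trailing ".gz"
        with gzip.open(gz_path, "rb") as src, open(out_path, "wb") as dst:
            shutil.copyfileobj(src, dst)  # stream the decompressed bytes to disk
        out_paths.append(out_path)
    return out_paths
```

Calling `gunzip_all(".")` from the folder that holds the source code would decompress all four archives in place.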

The source code for classifying MNIST with an SVM is as follows:

from sklearn import svm
import numpy as np
from time import time
from sklearn.metrics import accuracy_score
from struct import unpack
from sklearn.model_selection import GridSearchCV


def readimage(path):
    # IDX image file: 16-byte big-endian header (magic, count, rows, cols),
    # followed by one unsigned byte per pixel.
    with open(path, 'rb') as f:
        magic, num, rows, cols = unpack('>4I', f.read(16))
        img = np.fromfile(f, dtype=np.uint8).reshape(num, 784)
    return img


def readlabel(path):
    # IDX label file: 8-byte big-endian header (magic, count),
    # followed by one unsigned byte per label.
    with open(path, 'rb') as f:
        magic, num = unpack('>2I', f.read(8))
        lab = np.fromfile(f, dtype=np.uint8)
    return lab


def main():
    train_data = readimage("train-images.idx3-ubyte")
    train_label = readlabel("train-labels.idx1-ubyte")
    test_data = readimage("t10k-images.idx3-ubyte")
    test_label = readlabel("t10k-labels.idx1-ubyte")

    svc = svm.SVC()
    parameters = {'kernel': ['rbf'], 'C': [1]}
    print("Train...")
    clf = GridSearchCV(svc, parameters, n_jobs=-1)
    start = time()
    clf.fit(train_data, train_label)
    end = time()
    t = end - start
    print('Train:%dmin%.3fsec' % (t//60, t - 60 * (t//60)))

    prediction = clf.predict(test_data)
    print("accuracy: ", accuracy_score(test_label, prediction))

    # Per-digit counts: correctly classified samples and total samples.
    accurate = [0] * 10
    sumall = [0] * 10
    i = 0
    while i < len(test_label):
        sumall[test_label[i]] += 1
        if prediction[i] == test_label[i]:
            accurate[test_label[i]] += 1
        i += 1
    print("Correctly classified:", accurate)
    print("Total test labels:", sumall)


if __name__ == '__main__':
    main()
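To make the header parsing in readimage concrete, here is a small sketch that writes a tiny synthetic file in the same layout (four big-endian 32-bit integers followed by raw bytes) and reads it back. The helper names, the temporary file and the 3×4 image size are illustrative assumptions; 2051 is the magic number used by MNIST image files.

```python
import os
import tempfile
from struct import pack, unpack

import numpy as np


def write_fake_images(path, images):
    """Write a uint8 array of shape (num, rows, cols) with an IDX-style header."""
    num, rows, cols = images.shape
    with open(path, "wb") as f:
        f.write(pack(">4I", 2051, num, rows, cols))  # magic, count, rows, cols
        f.write(images.astype(np.uint8).tobytes())   # one byte per pixel


def read_images(path):
    """Same idea as readimage, but the row length comes from the header."""
    with open(path, "rb") as f:
        magic, num, rows, cols = unpack(">4I", f.read(16))
        pixels = np.frombuffer(f.read(), dtype=np.uint8)
    return pixels.reshape(num, rows * cols)


# Two made-up 3x4 "images" instead of real MNIST data.
fake = np.arange(2 * 3 * 4, dtype=np.uint8).reshape(2, 3, 4)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "fake-images.idx3-ubyte")
    write_fake_images(path, fake)
    flat = read_images(path)

print(flat.shape)  # (2, 12)
```

Each image comes back as one flattened row, which is the shape the SVM expects; for real MNIST files the row length is 28 × 28 = 784.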

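After reading the data with readimage and readlabel, the program creates an SVM classifier and describes its settings in parameters. GridSearchCV takes this parameter grid as input and searches it for the best combination. Here the grid contains only one combination (an RBF kernel with C=1), so nothing is actually tuned; more candidate values can be added to the grid to optimize the parameters. The main reason for using GridSearchCV in this code is to gain access to n_jobs, which lets the CPU work in parallel: with n_jobs=-1 the number of parallel jobs equals the number of CPU cores, which is intended to speed up the run. Note that this parallelism applies to the cross-validation fits that GridSearchCV performs; the final refit on the full training set is a single SVC fit. On a quad-core i5-8300H CPU the training took 26 min.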
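Adding more candidates to the parameter grid can be sketched as follows. To keep the run short, this sketch uses scikit-learn's small built-in digits dataset (8×8 images) instead of MNIST. The candidate values for C and gamma and the choice of cv=3 are arbitrary assumptions for illustration, not tuned settings from the original code.

```python
from sklearn import svm
from sklearn.datasets import load_digits
from sklearn.model_selection import GridSearchCV, train_test_split

# Small stand-in dataset (1797 images of 8x8 pixels), NOT MNIST.
X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0
)

# Several candidates per parameter; the values are illustrative only.
parameters = {"kernel": ["rbf"], "C": [1, 10], "gamma": ["scale", 0.001]}

# n_jobs=-1 runs the cross-validation fits on all available cores.
clf = GridSearchCV(svm.SVC(), parameters, n_jobs=-1, cv=3)
clf.fit(X_train, y_train)

print("best parameters:", clf.best_params_)
print("test accuracy:", clf.score(X_test, y_test))
```

With four parameter combinations and three folds, GridSearchCV has twelve independent fits to distribute across cores, which is where n_jobs can help.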

The accuracy obtained with the source code is shown below:

[Image 1: screenshot of the program's accuracy output]
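The per-digit counts that the program prints (accurate and sumall) can also be computed without an explicit loop. Here is a minimal sketch with NumPy; the short label arrays are made up for illustration and stand in for the real test_label and prediction arrays.

```python
import numpy as np

# Made-up stand-ins for test_label and prediction (not real results).
test_label = np.array([0, 1, 1, 2, 2, 2, 9], dtype=np.uint8)
prediction = np.array([0, 1, 7, 2, 2, 3, 9], dtype=np.uint8)

# Number of test samples per digit.
sumall = np.bincount(test_label, minlength=10)
# Number of correctly classified samples per digit.
accurate = np.bincount(test_label[prediction == test_label], minlength=10)

print("Correctly classified:", accurate.tolist())
print("Total test labels:", sumall.tolist())
```

Dividing accurate by sumall (for digits that actually occur) gives the per-digit accuracy, which shows which digits the classifier confuses most.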

Feel free to discuss in the comments.

Article URL: https://blog.csdn.net/qq_43160985/article/details/107675241
