minst 项目_Update2

github:https://github.com/wuzy361/mnist_homework_project
github上显示上次更新已经是三周前,这个项目搁置了很久了。
在第一次更新中,主要工作是把数据集的接口写好了,直接把用二进制文件保存的数据集转化成python能直接处理的ndarray(numpy中的数组)。
其实有了规定格式的数据集后,使用sklearn库,就能很方便的做一些机器学习的工作了。以下是探索数据进行机器学习的一些经验:

1,numpy.squeeze()

关于这个函数最早是在http://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html 里看到的,里面的变量虽然不是numpy类型的而是torch.FloatTensor类型的。但是二者非常类似,方法名也一样。numpy.squezze()是这样的:

minst 项目_Update2_第1张图片
Paste_Image.png

作用是把一维从数组里移除,不如shape本来是(1,2,3),代表一个三维张量,但其实和一个二维矩阵内容是一样的。混用(1,2,3)和(2,3)可能导致程序产生警告,甚至是错误,所以应该使用squeeze处理矩阵。

minst 项目_Update2_第2张图片
Paste_Image.png

该代码应该改成:

minst 项目_Update2_第3张图片
Paste_Image.png

2,navie bayes 和svm

navie bayes 分类器应该是最简单的分类器了,试验结果是这样的:

Paste_Image.png

对于60000个训练数据集和1000个测试数据集来说,naive bayes用时非常短,但准确度很低,只有55.58%,毕竟NAIVE啊。

svm就不一样了,早就直到svm非常慢,在这不算小的数据集下,svm太慢了,所以我就先用pca降维了。
把原来的28×28 = 784维的数据降成80维,之后再训练:

Paste_Image.png

精确度很高,到达了98.18%,

minst 项目_Update2_第4张图片
Paste_Image.png

跟周志华的 论文的对比实验接近,稍微差一点的是由于使用了pca,不可避免有点数据损失。

你可能感兴趣的:(minst 项目_Update2)