使用matlab在alexnet上迁移学习训练mnist

使用matlab在alexnet上迁移学习训练mnist

    • 安装matlab r2018b
    • 下载mnist数据集
    • 处理mnist数据集
    • 迁移学习
    • 查看分类错误的数据

安装matlab r2018b

安装过程参照这篇博客(侵删)
matlab中的toolbox可以直接在它的gui上下载,所以缺什么包就安装什么就好了。

下载mnist数据集

这是下载地址,下载完过后是四个压缩包,分别对应训练、测试数据的图片和标签。这个网站可能下载会比较慢,可以选择从云盘下载。

处理mnist数据集

这里我选择用foldernames作为labelsource,所以就建立以对应标签名为label的文件夹,将图片放进去,可以用python写一个脚本来处理,如果是从云盘下载的mnist.pkl,可以用以下代码来处理,将该脚本与mnist.pkl放在同一目录下运行即可生成图片。

import pickle
import os
from PIL import Image
data_file='./mnist.pkl'
with open(data_file, 'rb') as f:
	dataset=pickle.load	(f)

for i in range(10):
	os.makedirs('./data\\test_data\\'+str(i))
	os.makedirs('./data\\train_data\\'+str(i))
dataset['test_img']=dataset['test_img'].reshape(-1,28,28)
dataset['train_img']=dataset['train_img'].reshape(-1,28,28)

for i in range(len(dataset['train_img'])):
	im = Image.fromarray(dataset['train_img'][i])
	im.save('./data\\train_data\\'+str(dataset['train_label'][i])+'\\pic'+str(i)+'.jpg')
	if i%1000==0:
		print(str(i)+' pics finished!')

print('train_data finished!')

for i in range(len(dataset['test_img'])):
	im = Image.fromarray(dataset['test_img'][i])
	im.save('./data\\test_data\\'+str(dataset['test_label'][i])+'\\pic'+str(i)+'.jpg')
	if i%1000==0:
		print(str(i)+' pics finished!')
		
print('test_data finished!')

迁移学习

我所理解的迁移学习就是将别人训练好的神经网络某些层进行修改,使它适应某项特定的任务。因为它具有先前的学习经验,所以在学习新东西时要比从头开始学更快。如果我的理解有什么不对的,还请指教,毕竟matlab官网上是这么教的。此处我采用了alexnet,将它的最后几层进行修改,然后再对训练数据集进行训练。在matlab中,关于神经网络的库被高度封装,所以写起来很简便。代码如下:

clear
load pathToImages
net=alexnet;
datTrain=imageDatastore(strcat(pathToImages,'train_data\'),'IncludeSubfolders',true,'LabelSource','foldernames');
datTest=imageDatastore(strcat(pathToImages,'test_data\'),'IncludeSubfolders',true,'LabelSource','foldernames');
%导入数据
layers=net.Layers;
fc=fullyConnectedLayer(10);
layers(23)=fc;
layers(25)=classificationLayer;
%第23层原本是一千个神经元,用于识别一千中图像,现在改为10个神经元的图像,并且更新最后一层
opts=trainingOptions('sgdm','InitialLearnRate',0.001)
%设置参数,使用sgdm优化器,学习率设为0.001
trainimgs=augmentedImageDatastore([227,227],datTrain,'ColorPreprocessing','gray2rgb');
testimgs=augmentedImageDatastore([227,227],datTest,'ColorPreprocessing','gray2rgb');
%数据增强:将原本的数据类型(27*27)改为(227*227*3)的图像
[newnet,info]=trainNetwork(trainimgs,layers,opts)
%开始训练,保存训练出来的网络以及训练信息
plot(info.TrainingLoss);
%输出loss图像
actul=datTest.Labels;
preds=classify(newnet,testimgs);

numCorrect=nnz(actul==preds)
%nnz->Count non-zero elements in an array

fracCorrect=numCorrect/numel(preds)
%计算准确率
confusionchart(actul,preds)
%混淆矩阵

我花了大约70分钟训练完毕,在测试集上的准确率达到了0.9958,应该算是一个不错的准确率了。

查看分类错误的数据

虽然准确率还是不错,不过我还是想要看看到底怎样的图片会被分类错误,于是我又写了下面的脚本来查看分类错误的图片。

cnt=0;
for i=1:10000
	if preds(i)~=actul(i)
		cnt=cnt+1;
		figure(cnt)
		tmp=datTest.Files(i);
		tmp=tmp{1};
		imshow(tmp)
		str1=cellstr(preds(i));
		str2=cellstr(actul(i));
		title(strcat(str1,' ',str2));
	end
end

作为matlab小白,我写的代码比较丑陋,还请多多包涵。

之后我得到了一些类似于这样的图片。
使用matlab在alexnet上迁移学习训练mnist_第1张图片
使用matlab在alexnet上迁移学习训练mnist_第2张图片
使用matlab在alexnet上迁移学习训练mnist_第3张图片
使用matlab在alexnet上迁移学习训练mnist_第4张图片
使用matlab在alexnet上迁移学习训练mnist_第5张图片
每个图片上面的数字,左边是预测结果,右边是实际结果。从中可以看到,这些图片即使是靠肉眼也难以分辨准确,(也许是因为写这些数字的人书写太差了)。因此也没有必要再追求更高的准确率了。

你可能感兴趣的:(深度学习)