安装过程参照这篇博客(侵删)
matlab中的toolbox可以直接在它的gui上下载,所以缺什么包就安装什么就好了。
这是下载地址,下载完过后是四个压缩包,分别对应训练、测试数据的图片和标签。这个网站可能下载会比较慢,可以选择从云盘下载。
这里我选择用foldernames作为labelsource,所以就建立以对应标签名为label的文件夹,将图片放进去,可以用python写一个脚本来处理,如果是从云盘下载的mnist.pkl,可以用以下代码来处理,将该脚本与mnist.pkl放在同一目录下运行即可生成图片。
import pickle
import os
from PIL import Image
data_file='./mnist.pkl'
with open(data_file, 'rb') as f:
dataset=pickle.load (f)
for i in range(10):
os.makedirs('./data\\test_data\\'+str(i))
os.makedirs('./data\\train_data\\'+str(i))
dataset['test_img']=dataset['test_img'].reshape(-1,28,28)
dataset['train_img']=dataset['train_img'].reshape(-1,28,28)
for i in range(len(dataset['train_img'])):
im = Image.fromarray(dataset['train_img'][i])
im.save('./data\\train_data\\'+str(dataset['train_label'][i])+'\\pic'+str(i)+'.jpg')
if i%1000==0:
print(str(i)+' pics finished!')
print('train_data finished!')
for i in range(len(dataset['test_img'])):
im = Image.fromarray(dataset['test_img'][i])
im.save('./data\\test_data\\'+str(dataset['test_label'][i])+'\\pic'+str(i)+'.jpg')
if i%1000==0:
print(str(i)+' pics finished!')
print('test_data finished!')
我所理解的迁移学习就是将别人训练好的神经网络某些层进行修改,使它适应某项特定的任务。因为它具有先前的学习经验,所以在学习新东西时要比从头开始学更快。如果我的理解有什么不对的,还请指教,毕竟matlab官网上是这么教的。此处我采用了alexnet,将它的最后几层进行修改,然后再对训练数据集进行训练。在matlab中,关于神经网络的库被高度封装,所以写起来很简便。代码如下:
clear
load pathToImages
net=alexnet;
datTrain=imageDatastore(strcat(pathToImages,'train_data\'),'IncludeSubfolders',true,'LabelSource','foldernames');
datTest=imageDatastore(strcat(pathToImages,'test_data\'),'IncludeSubfolders',true,'LabelSource','foldernames');
%导入数据
layers=net.Layers;
fc=fullyConnectedLayer(10);
layers(23)=fc;
layers(25)=classificationLayer;
%第23层原本是一千个神经元,用于识别一千中图像,现在改为10个神经元的图像,并且更新最后一层
opts=trainingOptions('sgdm','InitialLearnRate',0.001)
%设置参数,使用sgdm优化器,学习率设为0.001
trainimgs=augmentedImageDatastore([227,227],datTrain,'ColorPreprocessing','gray2rgb');
testimgs=augmentedImageDatastore([227,227],datTest,'ColorPreprocessing','gray2rgb');
%数据增强:将原本的数据类型(27*27)改为(227*227*3)的图像
[newnet,info]=trainNetwork(trainimgs,layers,opts)
%开始训练,保存训练出来的网络以及训练信息
plot(info.TrainingLoss);
%输出loss图像
actul=datTest.Labels;
preds=classify(newnet,testimgs);
numCorrect=nnz(actul==preds)
%nnz->Count non-zero elements in an array
fracCorrect=numCorrect/numel(preds)
%计算准确率
confusionchart(actul,preds)
%混淆矩阵
我花了大约70分钟训练完毕,在测试集上的准确率达到了0.9958,应该算是一个不错的准确率了。
虽然准确率还是不错,不过我还是想要看看到底怎样的图片会被分类错误,于是我又写了下面的脚本来查看分类错误的图片。
cnt=0;
for i=1:10000
if preds(i)~=actul(i)
cnt=cnt+1;
figure(cnt)
tmp=datTest.Files(i);
tmp=tmp{1};
imshow(tmp)
str1=cellstr(preds(i));
str2=cellstr(actul(i));
title(strcat(str1,' ',str2));
end
end
作为matlab小白,我写的代码比较丑陋,还请多多包涵。
之后我得到了一些类似于这样的图片。
每个图片上面的数字,左边是预测结果,右边是实际结果。从中可以看到,这些图片即使是靠肉眼也难以分辨准确,(也许是因为写这些数字的人书写太差了)。因此也没有必要再追求更高的准确率了。