上一章,我们学习了利用逻辑回归算法处理复杂二分类问题的方法,这一章,我们通过手写数字识别的例子来介绍逻辑回归算法对多分类问题的解决办法。
我们在写信或寄快递的时候都要填上邮编和手机号来标明邮寄的目的地和收件人,如果我们有一个程序能够自动识别出不同人手写的数字,并将其录入计算机中,无疑会大大增加邮寄的效率。然而,由于每个人的书写习惯不同,要直接找出一个适用于每个人的识别算法殊为不易。这一章我们来介绍通过逻辑回归的方法,让机器自动学习手写字的特征,进而识别其代表的数字。
首先,我们需要有一个训练集。如下图所示,吴老师的课程提供了5000个不同人手写的数字图片,并且标出了每一幅图片对应的正确数字。
这些图片和相应的数字标签就是我们的训练集,每一幅图片的分辨率都是 20 × 20 20×20 20×20的,对应于400个实数。我们将每一幅图片的这400个数字作为输入特征,再加上一个常数项1,共有401个特征。
与二分类问题不同,这里的输出结果是0-9共十个数字,因此处理方法也有所不同。多分类问题的处理方法通常是将其划分为多个二分类问题来处理。在本例中,我们一共有十个类别,因此就要运行十次逻辑回归算法,分别识别出某一幅图片是 0 , 1 , 2 , ⋯   , 9 0,1,2,\cdots,9 0,1,2,⋯,9的概率,取概率最大的那个数字作为最终的识别结果。
具体实现过程,我们下面结合代码来详细讲解。
首先,导入数据,可以看到,数据中有 X X X和 y y y两个变量,其中 X X X是一个 5000 × 400 5000×400 5000×400的矩阵,每一行代表一个手写数字的图片,共有5000幅图, y y y是一个 5000 × 1 5000×1 5000×1的向量,每一行都是一个0到9之间的整数,代表对于行的 X X X的标签。
load('ex3data1.mat');
接下来,我们通过下面的代码来将一部分手写数字图片显示出来。思路是,先从训练集中随机抽取出100行 X X X数据,然后将每一行 X X X整形为 20 × 20 20×20 20×20的矩阵,最后再将这100个矩阵按 10 × 10 10×10 10×10的规律组合成一个大的矩阵,最后用imshow函数将这个大矩阵显示出来,结果就是这一章最开始的那幅图片了。
%Display the data;
m=length(y);
rand_indices=randperm(m);
rand_display_x=X(rand_indices(1:100),:);
rand_display_y=y(rand_indices(1:100),:);
img_display=zeros(200,200);
figure,
for k=1:100
x_indice=mod(k-1,10)*20+1;
y_indice=floor((k-1)/10)*20+1;
img_display(y_indice:y_indice+19,x_indice:x_indice+19)=reshape(rand_display_x(k,:),20,20);
end
imshow(img_display);
在对训练集有了一个直观的了解后,我们下一步就开始进入数字识别的算法中去。首先,我们要给 X X X增加一列常数项1作为偏置项,然后将这些图片的顺序打乱,让我们的识别算法能够更好地工作。
%Prepare the test samples.
K=10;
x=[ones(m,1),X];
x=x(rand_indices,:);
y=y(rand_indices,:);
接下来,我们冲总的训练集中抽出70%的样本作为真正的训练集,另外30%的样本做为测试集,用来等我们的算法训练完成后测试其最终的正确率。
test_num=round(0.7*m);
xt=x(1:test_num,:);
yt=y(1:test_num);
下一步:初始化参数向量 θ \mathbf{\theta} θ,和正则化参数 λ \lambda λ。这里,由于我们要运行十次逻辑回归算法,因此会得到十个参数向量,我们将这些向量放在一起组成一个参数矩阵Theta,这个矩阵在后面会非常有用。
[m,n]=size(xt);
theta=zeros(n,1);
lamda=0.02;
Theta=zeros(n,K);
接下来,就是利用高级优化算法fmincg分别学习数字 0 , 1 , 2 , ⋯   , 9 0,1,2,\cdots,9 0,1,2,⋯,9的识别参数。具体而言,就是针对每一个数字,例如‘3’,我们将 y y y中等于‘3’的数字置为1,不等于‘3’的数字置为0。这样我们的训练集就变成一个二分类问题的训练集了,然后就可以用上一章的方法寻找最优的参数theta了。重复以上过程十次,就可以得到针对每一个数字的最优参数了。
这里的代价函数CostFunctionReg与上一章中使用的代价函数是一样的。而这里使用的高级优化算法fmincg吴老师的作业中也给出了其代码(见文末下载),我们不必深究,直接拿来使用即可。到这里,我们就完成了数字识别的学习过程,是不是很简单呀。
%Repeat the minimization process for every label.
options = optimset('GradObj', 'on', 'MaxIter', 500);
for k=1:K
yi=yt==k;
[theta, cost] =fmincg(@(t)(CostFunctionReg(t, xt, yi,lamda)), theta, options);
Theta(:,k)=theta;
end
然而,到这里,我们只是得到了一个 400 × 10 400×10 400×10的参数矩阵Theta,这个矩阵到底能不能帮我们正确识别手写数字图片呢?这就要用到我们的测试集了。
如下所示,因为我们学习的过程只用到了70%的样本,如果我们学到的参数是正确的话,那么我们应该能够准确识别剩下的30%的样本也就是测试集中的图片。我们用xc和yc来表示测试集的特征和标签,通过代码h=Sigmoid(xc*Theta);[value,index]=max(h,[],2);
来计算我们预测的结果。看到矩阵Theta的好处了吧,我们只需要一行代码,就可以计算出所有数字的预测概率。其中,h是一个 0.3 m × 10 0.3m×10 0.3m×10的矩阵,其每一行代表的是我们预测的对应样本是 0 , 1 , 2 , ⋯   , 9 0,1,2,\cdots,9 0,1,2,⋯,9的概率,其最大值所在的列数就是我们最终的预测结果。将我们的预测结果与实际的正确结果进行对比,就可以得出我们预测的准确率了,当然你也可以调节正则化参数lambda看看其对准确率的影响,我得到的准确率在92%左右。
%Check the accuracy of prediction.
test_num=round(0.7*m);
xc=x(test_num+1:end,:);
yc=y(test_num+1:end);
h=Sigmoid(xc*Theta);
[value,index]=max(h,[],2);
result=[yc,index];
accuracy=sum(index==yc)/length(yc);
如果我们想享受一下自己的预测过程呢,可以通过以下代码来逐个显示测试集中的图片和我们的预测结果了。
%Display the predicting process
figure,
for i=1:100
data=x(test_num+i,:);
img=reshape(data(2:end),20,20);
%subplot(10,10,i),
img=imresize(img,30);
imshow(img);
h=Sigmoid(data*Theta);
[value,index]=max(h);
if index==10
index=0;
end
label=y(test_num+i);
if label==10
label=0;
end
xlabel([num2str(label),',',num2str(index),'(',num2str(value),')']);
drawnow;
pause(1.5);
end
以上这段代码的运行结果如下图所示,图片下方的数字意思为:正确值,预测值(正确的概率)。如果我们运行了这段程序就会发现,我们的准确率之所以不能达到100%实在是因为有些人写的数字太潦草了,连我们自己都没法准确识别,更别说我们的电脑了。。
好了,到这里,我们的手写数字识别的小程序就完成了。
以上代码可以从这里下载,提取码3kua
这一章通过一个手写数字识别的小程序介绍了多分类问题的处理方法,以及如何验证我们学习算法的效果。下一章将介绍另一个手写数字识别的算法:神经网络算法。