AndrewNg机器学习第四周作业:关于使用逻辑回归、神经网络训练数据并应用之的心得

ex3的作业是根据已有的数据集

(20*20像素的图片,每个像素是一个feature,总共400个features,400个features作为输入X,数据集已经包含输出的y,代表这是什么数字)

,去识别手写数字。

首先是使用逻辑回归方法来分类10个数字(分类问题)。

一、逻辑回归 参数theta的训练与预测

一开始的theta矩阵是ones创建的,构建逻辑回归的cost Function和gradient,然后把数据集X和已知结果y扔到里面去训练(这个“训练”过程实际就是一种代数运算),求出我们的theta(维度:10*401),这里需要讲一下为什么维度是10*401,


我们要识别的数字是10个,然后输入的features是400个,为甚么后来变成了401呢,因为多了一个theta0,所以是401。
然后因为我们是分别对10个数字做逻辑回归,theta的每一行代表对应这个数字我们的theta参数值
通俗来讲,就是,我们想象出一个401维的坐标轴,然后里面有1000个标记点(分别代表1到10这些数字),然后我们现在要构建一个超级复杂的函数hx1(拥有400个参数)来将代笔1这个数字的标记点包起来(不一定要完美无缺的包起来),同理,我们还要找到hx2、hx3····来把其他数字标记点包起来,所以这样就构成了我们的theta矩阵。


那么,我们的theta矩阵训练完之后我们要怎么去预测一个新数据集呢?

现在给出一个新的数据集X2(1000*401)然后把这个X2*(theta的转置)矩阵相乘之后 带入逻辑函数中(即S型函数)后得到一个 矩阵B(1000*10),矩阵B的每一行代表 数据集X2对应的那一行数据输入的结果预测,矩阵B第一行第一列的数值代表X2第一行输入的features经过运算后逻辑回归判断是数字1的可能性,之后以此类推,我们记录矩阵B中每一行最大的数值所在的列,即得出了我们预测的数字。

二、神经网络 参数theta的训练与预测

神经网络与逻辑回归的处理方式基本一样,只不过theta矩阵 视 神经网络的层数和每层的神经元个数不同~~
附上神经网络预测的代码:

function p = predict(Theta1, Theta2, X)
%PREDICT Predict the label of an input given a trained neural network
%   p = PREDICT(Theta1, Theta2, X) outputs the predicted label of X given the
%   trained weights of a neural network (Theta1, Theta2)

% Useful values
m = size(X, 1);
num_labels = size(Theta2, 1);

% You need to return the following variables correctly 
p = zeros(size(X, 1), 1);

% ====================== YOUR CODE HERE ======================
% Instructions: Complete the following code to make predictions using
%               your learned neural network. You should set p to a 
%               vector containing labels between 1 to num_labels.
%
% Hint: The max function might come in useful. In particular, the max
%       function can also return the index of the max element, for more
%       information see 'help max'. If your examples are in rows, then, you
%       can use max(A, [], 2) to obtain the max for each row.
%


a1=[ones(m,1) X];
z2=Theta1*a1';
a2=[ones(1,m);sigmoid(z2)];
z3=Theta2*a2;
output=z3';
[c,i]=max(output,[],2);
p=i;






% =========================================================================


end

你可能感兴趣的:(机器学习,机器学习,神经网络)