机器学习笔记——神经网络的实现

0x00 前言

这篇算是神经网络学习的一篇实现总结,神经网络算法无疑是最好的实现人类“人工智能”的算法,与安全相关的是,神经网络已经在WAF领域、SQL注入检测、Webshell检测领域有了方法论,为了最终能和安全技术相结合,首先要搞清楚这个算法的整个过程。
这里不会详细的从各个层面总结神经网络算法,而是通过实现的代码来学习,我想这个方法无疑是学习任何机器学习模型的最好方法。

0x01 训练的步骤

这里的训练步骤是使用了吴恩达博士的slide,有兴趣的可以去Coursera上听一听。
(1)随机初始化每层的theta参数矩阵
(2)对每一个 xj ,通过前向传导算法(FP)获取 hθ(x(i))
(3)计算Cost Function, J(θ)
其中 J(θ) 的表达式为:

(4)使用反向传播算法(BP)计算 ϑϑΘ(l)jkJ(Θ)
(5)使用梯度检测比较BP得出的梯度和数值计算出来的梯度是否基本相同
(6)在反向传播算法的基础上使用梯度下降和其他优化算法最小化 J(θ)

0x02 FP算法步骤

FP算法如下,用来从第一层开始,逐步向后传导,最终求出最后一层即输出层(output layer)的输出函数:

训练集为5000个,其中有400个维度(不算偏置单元的theta),Octave实现如下:

% 前向传导算法实现,计算出layer3的h函数
layer1 = [ones(m,1) X]; % 5000 * 401
layer2 = [ones(m,1) sigmoid(layer1*Theta1')]; % 5000 * 26
layer3 = sigmoid(layer2*Theta2');  % 5000 * 10

可以看到FP算法就是一层一层从前向后逐步求出output layer的值。

0x03 BP算法步骤

在求每层单元误差的时候,必须使用后向传播算法,步骤如下,同样来自于吴恩达博士的slide:

对于每个训练集样本,我们都要执行一次BP算法来计算对应的Delta的值。
关于delta的计算,首先要算出输出层的delta,然后后向传播:

后向传播的初始化就是输出层的误差delta,在上面一张slide上也能很清楚的看到传播的计算过程。
具体实现如下:

% 实现后向传播算法
Delta1 = zeros(size(Theta1));
Delta2 = zeros(size(Theta2));
for i=1:m,
  % 前向传导算法求h(theta)
  layer1 = [1; X(i,:)’]; % 401 * 1
  z2 = Theta1 * layer1;
  layer2 = [1; sigmoid(Theta1*layer1)]; % 26 * 1
  layer3 = sigmoid(Theta2*layer2);   % 10 * 1

  % 计算layer3(最后一层) 的error
  % h_delta = a3-y(i)
  delta3 = zeros(num_labels, 1); % 10 * 1
  for c=1:num_labels,
    delta3(c) = layer3(c) - (y(i)==c);
  end;

  % 根据layer3计算前面的delta
  % 计算deta2 = (theta(2))’*delta3 .* (a(2) .* (1-a(2)))
  delta2 = Theta2’ * delta3;
  delta2 = delta2(2:end) .* sigmoidGradient(z2);  % 25 * 1

  % 计算Delta
  Delta2 = Delta2 + delta3 * layer2’ ;
  Delta1 = Delta1 + delta2 * layer1’ ;
end;

% step5 非正规化与正规化的区别在于 lambda 等于0 和 lambda 不等于 0
% 正则化不包括偏置单元,所以置theta的第一列为0
Theta1_grad = Delta1 / m + lambda/m * [zeros(size(Theta1,1),1) Theta1(:,2:end)];
Theta2_grad = Delta2 / m + lambda/m * [zeros(size(Theta2,1),1) Theta2(:,2:end)];

0x04 总结

神经网络算是一个比较复杂的机器学习算法了,复杂的好处在于运用这种算法,可以轻松解决一些维数很多的样本集进行学习。比如谷歌一直研发的自动驾驶技术、百度brain等一些机器学习项目,基本都是基于神经网络来进行实现。所以,使用神经网络来解决安全问题,似乎有点儿杀鸡用牛刀的意思。

你可能感兴趣的:(机器学习&数据挖掘)