干货 :深入浅出神经网络的改进方法!

高尔夫球员刚开始学习打高尔夫球时,通常会花很长时间练习挥杆。慢慢地,他们才会在此基础上练习其他击球方式,学习削球、左曲球和右曲球。本章仍着重介绍反向传播算法,这就是我们的“挥杆基本功”——神经网络中大部分工作、学习和研究的基础。

本文将着重讲解利用交叉熵代价函数改进神经网络的学习方法。

一、交叉熵代价函数

大多数人不喜欢被他人指出错误。我以前刚学习弹钢琴不久,就在听众前做了一次首秀。我很紧张,开始时错将八度音阶的曲段演奏得很低。我不知所措,因为演奏无法继续下去了,直到有人指出了其中的错误。我当时非常尴尬。不过,尽管不愉快,我们却能因为明显的错误而快速地学到正确的知识。下次我肯定能演奏正确!然而当错误不明确的时候,学习会变得非常缓慢。学习速度下降的原因实际上也是一般的神经网络学习缓慢的原因,并不仅仅是特有的。

引入交叉熵代价函数

如何解决这个问题呢?研究表明,可以使用交叉熵代价函数来替换二次代价函数。

将交叉熵看作代价函数有两点原因。第一,它是非负的,C > 0。可以看出(57)的求和中的所有单独项都是负数,因为对数函数的定义域是(0, 1)。求和前面有一个负号。

第二,如果对于所有的训练输入x,神经元实际的输出都接近目标值,那么交叉熵将接近0。假设在本例中,y = 0而a ≈ 0,这是我们想要的结果。方程(57)中的第一个项会消去,因为y = 0,而第二项实际上就是−ln(1 − a) ≈ 0;反之,y = 1而a ≈ 1。所以实际输出和目标输出之间的差距越小,最终交叉熵的值就越小。

综上所述,交叉熵是非负的,在神经元达到较高的正确率时接近0。我们希望代价函数具备这些特性。其实二次代价函数也拥有这些特性,所以交叉熵是很好的选择。然而交叉熵代价函数有一个比二次代价函数更好的特性:它避免了学习速度下降的问题。

代价函数曲线要比二次代价函数训练开始部分陡峭很多。这个交叉熵导致的陡度正是我们期望的,当神经元开始出现严重错误时能以最快的速度学习。

二、改进神经网络的应用实践

使用交叉熵来对MNIST数字进行分类

如果程序使用梯度下降算法和反向传播算法进行学习,那么交叉熵作为其中一部分易于实现。我们会使用一个包含30个隐藏神经元的网络,小批量的大小也设置为10,将学习率设置为 η ,训练30轮。network2.py的接口和network.py的略有区别,但用法还是很好懂的。可以在Python shell中使用help(network2.Network.SGD)这样的命令来查看network2.py的接口文档。

>>> import mnist_loader
>>> training_data, validation_data, test_data = \
... mnist_loader.load_data_wrapper()
>>> import network2
>>> net = network2.Network([784, 30, 10], cost=network2.CrossEntropyCost)
>>> net.large_weight_initializer()
>>> net.SGD(training_data, 30, 10, 0.5, evaluation_data=test_data,
... monitor_evaluation_accuracy=True)

注意,net.large_weight_initializer()命令使用第1章介绍的方式来初始化权重和偏置。这里需要执行该命令,因为后面才会改变默认的权重初始化命令。运行上面的代码,神经网络的准确率可以达到95.49%,这跟第1章中使用二次代价函数得到的结果(95.42%)相当接近了。

对于使用100个隐藏神经元,而交叉熵及其他参数保持不变的情况,准确率达到了96.82%。相比第1章使用二次代价函数的结果(96.59%)有一定提升。看起来是很小的变化,但考虑到误差率已经从3.41%下降到3.18%了,消除了原误差的1/14,这其实是可观的改进。

跟二次代价相比,交叉熵代价函数能提供类似的甚至更好的结果,然而这些结果不能证明交叉熵是更好的选择,原因是在选择学习率、小批量大小等超参数上花了一些心思。为了让提升更有说服力,需要对超参数进行深度优化。然而,这些结果仍然是令人鼓舞的,它们巩固了先前关于交叉熵优于二次代价的理论推断。

转自:Datawhale ; Michael Nielsen,计算机科学家;

本文内容节选自《深入浅出神经网络与深度学习》一书,由Michael Nielsen所著,李航等推荐。本书深入了讲解神经网络和深度学习技术,作者以技术原理为导向,辅以贯穿全书的 MNIST 手写数字识别项目示例,讲解了如何利用所学知识改进深度学习项目,值得学习。

END

版权声明:本号内容部分来自互联网,转载请注明原文链接和作者,如有侵权或出处有误请和我们联系。


合作请加QQ:365242293  

数据分析(ID : ecshujufenxi )互联网科技与数据圈自己的微信,也是WeMedia自媒体联盟成员之一,WeMedia联盟覆盖5000万人群。

你可能感兴趣的:(神经网络,算法,人工智能,深度学习,编程语言)