深度学习在执行梯度下降算法时,通常会面临一系列的问题。如陷入local minimun、saddle point,训练很慢或不收敛等诸多问题。因此需要对梯度下降算法进行优化,优化的考量主要有三个方面:
我们假设某一个数据集的训练集中有 10000 个样本。
在BGD的一个epoch中,更新一次参数,需要用到训练集中所有样本(10000个),通过计算这 10000 个输入样本的 loss function,进行反向传播来更新参数。即一次迭代需要用到训练集中的 所有数据。每一个epoch只更新一次参数。
θ = θ − η ∇ θ J ( θ ) \theta = \theta - \eta \nabla_{\theta} J(\theta) θ=θ−η∇θJ(θ)
优点:由于考虑到所有的样本,容易得到全局最优解,同时利于并行计算;
缺点:更新速度慢,占用较多内存。
SGD与BGD恰恰相反,在SGD中,每计算一个样本的loss function,就做一次反向传播来更新参数。即每一次参数更新只需要一个样本。在一个epoch中,参数更新 10000 次,训练速度显然快于SGD。
θ = θ − η ∇ θ J ( θ , x ( i ) , y ( i ) ) \theta = \theta - \eta \nabla_{\theta} J(\theta,x^{(i)},y^{(i)}) θ=θ−η∇θJ(θ,x(i),y(i))
优点:训练速度快,占用内存少;
缺点:由于只根据一个样本进行梯度下降,所以很难得到全局最优,不易实现并行计算。
我们看到,SGD和BGD都有各自的优缺点,那么我们能不能做一个折衷呢?汲取各自的优点,弥补各自的缺点,Mini-batch Gradient Descent既考虑参数训练的速度,又考虑到训练的稳定性,以得到全局最优。Mini-batch是将训练集数据分为若干个数据量相同的batch,假设一个batch中有100个样本,则整个数据集(假设10000个样本)中就有100个batch。每更新一次参数,需要利用到一个batch的数据,这样既不像BGD训练得那么慢,又不会像SGD那样失去全局的考虑。
θ = θ − η ∇ θ J ( θ , x ( i + n ) , y ( i + n ) ) \theta = \theta - \eta \nabla_{\theta} J(\theta,x^{(i+n)},y^{(i+n)}) θ=θ−η∇θJ(θ,x(i+n),y(i+n))
以上三种算法存在两点共同的缺陷:
为了解决上述的第一个问题,即动态确定学习率,AdaGrad给出了相当不错的思路:在训练中有非常多的参数,有些更新的快有些更新较慢,所以AdaGrad为每一个参数都确定一个属于他自身的学习率。对于某些更新较快的参数,我们已经学习到他的一些知识,所以不希望单个样本能对其有太大的影响,因此学习率要设定的慢一些;对于更新很慢的参数,我们对其了解的信息太少,所以要将这个参数的学习率设定的更大一些。
假设 g t , i g_{t,i} gt,i 表示 t t t 时刻第 i i i 个参数的梯度,则此时梯度下降更新参数的公式为:
θ t + 1 , i = θ t , i − η ∑ j = 0 t ( g j , i ) 2 + ϵ g t , i \theta_{t+1,i}=\theta_{t,i}-\frac{\eta}{\sqrt{\sum_{j=0}^{t}(g_{j,i})^{2}+\epsilon}}g_{t,i} θt+1,i=θt,i−∑j=0t(gj,i)2+ϵηgt,i
通过上式可见,AdaGrad在初始学习率 η \eta η 的基础上,对于参数 i i i 除以该参数历史梯度的均方根, ϵ \epsilon ϵ 是防止分母为0的平滑项。这种做法消除了学习率恒定的弊端,使得频繁更新的参数学习率更小,更新缓慢的参数拥有较大的学习率。
但是AdaGrad存在一个问题,就是在训练的中后期,由于分母上梯度的不断累加,会导致学习率变得非常缓慢,在达到最优结果之前提前结束训练过程。
RMSProp将当前的梯度 g t , i g_{t,i} gt,i 与历史梯度分开讨论。设定一个0-1之间的参数 α \alpha α,每次算梯度时计算:
E [ g 2 ] t , i = α E [ g 2 ] t − 1 , i + ( 1 − α ) g t , i 2 E[g^{2}]_{t,i}=\alpha E[g^{2}]_{t-1,i}+(1-\alpha)g_{t,i}^{2} E[g2]t,i=αE[g2]t−1,i+(1−α)gt,i2
所以此时梯度更新的公式如下:
θ t + 1 , i = θ t , i − η E [ g 2 ] t , i + ϵ g t , i \theta_{t+1,i}=\theta_{t,i}-\frac{\eta}{\sqrt{E[g^{2}]_{t,i}+\epsilon}}g_{t,i} θt+1,i=θt,i−E[g2]t,i+ϵηgt,i
通过第一个公式可以发现,当 α \alpha α 的值接近于 0 时,当前的梯度 g t , i 2 g_{t,i}^{2} gt,i2 对学习率的确定起到更大的影响;当 α \alpha α 的值接近于 1 时,历史梯度的加权 E [ g 2 ] t − 1 , i E[g^{2}]_{t-1,i} E[g2]t−1,i 对学习率的影响占更大比重。
这个做法一定程度上解决了AdaGrad的缺点,在训练的中后期,万一出现梯度变稀疏的情况,我们希望学习率变大一些,而AdaGrad的单调性无法实现这一愿望。而RMSProp可根据当前的梯度和历史的梯度期望,做一个加权的考量。
在不引入Momentum之前,训练中很容易困在saddle point。即这个点loss很高,但是周围是平坦的,梯度为0,对应地理中的高原地形。此时需要引入类似于物理学中的冲量概念,即不能单单考虑这个点的梯度信息,同时要考虑上一个时刻的运动状态(类似于惯性)。有了这个考量,参数就有可能冲出鞍点的范围,继续寻找最优解。
上图摘自台大李宏毅老师的机器学习PPT。我们看到,在上图中红色箭头表示梯度方向;蓝色箭头代表参数更新的方向;绿色虚线代表上一次的参数更新方向。在第一次参数更新时,由于没有历史信息,所以参数更新方向就是梯度的反方向。在第二次及以后的参数更新时,参数的更新方向为上一次更新方向与此时梯度的反方向的加权向量和,即蓝色向量等于绿色向量与红色向量的加权和。
此时参数的更新形式如下:
θ = θ − v t \theta = \theta - v_{t} θ=θ−vt
而这里的 v t v_{t} vt 是当前梯度与前一个状态参数更新方向的向量加权和,可表示为:
v t = γ v t − 1 + ( − α g t ) v_{t} = \gamma v_{t-1}+(-\alpha g_{t}) vt=γvt−1+(−αgt)
可见每一次参数更新时会考虑到上一个时刻的参数更新方向 v t − 1 v_{t-1} vt−1。试想一下,假设将一个小球从坡面高处抛下,在下滑的过程中,如果遇到鞍点,这时参数更新方向也会参照之前的方向,继续向前向下,摆脱鞍点。所以这个冲量的引入使得参数摆脱鞍点或局部最优成为可能。
Adam算法可以看作RMSProp和Momentum的集大成者,兼顾了自适应学习率和冲量这两个因素,可以说是现在最稳定、效果最理想的训练参数方法之一。
首先其更新参数的公式如下:
θ t + 1 = θ t − η v t ^ + ϵ m t ^ \theta_{t+1}=\theta_{t}-\frac{\eta}{\sqrt{\hat{v_{t}}+\epsilon}}\hat{m_{t}} θt+1=θt−vt^+ϵηmt^
将 m t m_{t} mt 和 v t v_{t} vt 看作是对梯度的一阶矩估计和二阶矩估计,近似等于期望 E [ g t ] E[g_{t}] E[gt] 和 E [ g t 2 ] E[g_{t}^{2}] E[gt2],所以 m t m_{t} mt 和 v t v_{t} vt 的表达形式如下:
m t = β 1 m t − 1 + ( 1 − β 1 ) g t m_{t}=\beta_{1}m_{t-1}+(1-\beta_{1})g_{t} mt=β1mt−1+(1−β1)gt
v t = β 2 v t − 1 + ( 1 − β 2 ) g t 2 v_{t}=\beta_{2}v_{t-1}+(1-\beta_{2})g_{t}^{2} vt=β2vt−1+(1−β2)gt2
而在迭代开始的阶段, m t m_{t} mt 和 v t v_{t} vt 有一个向初始值0的偏移,因此第一个式子里的 m t ^ \hat{m_{t}} mt^ 和 v t ^ \hat{v_{t}} vt^ 分别是 m t m_{t} mt 和 v t v_{t} vt 的无偏估计,可以表达为:
m t ^ = m t 1 − β 1 t \hat{m_{t}}=\frac{m_{t}}{1-\beta_{1}^{t}} mt^=1−β1tmt
v t ^ = v t 1 − β 2 t \hat{v_{t}}=\frac{v_{t}}{1-\beta_{2}^{t}} vt^=1−β2tvt
将上面两个式子带入第一个公式即可。
参考:
李宏毅老师机器学习课程
https://blog.csdn.net/u010089444/article/details/76725843
https://zhuanlan.zhihu.com/p/32626442