深度学习入门之AdaGrad

在神经网络的学习中,学习率(数学式中记为 η )的值很重要。学习率过小,会导致学习花费过多时间;反过来,学习率过大,则会导致学习发散而不能正确进行。

在关于学习率的有效技巧中,有一种被称为学习率衰减(learning rate decay)的方法,即随着学习的进行,使学习率逐渐减小。即,一开始“多”学,然后逐渐“少”学。(学习的意思是朝着损失函数最低处进行优化)

AdaGrad 会为参数的每个元素适当地调整学习率,与此同时进行学习(AdaGrad 的 Ada 来自英文单词 Adaptive,即“适当的”的意思)。逐渐减小学习率的想法,相当于将“全体”参数的学习率值一起降低。

而 AdaGrad进一步发展了这个想法,针对“一个一个”的参数,赋予其“定制”的值。下面,让我们用数学式表示 AdaGrad 的更新方法。

watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5a2d5q2j5ZCM5a2m,size_15,color_FFFFFF,t_70,g_se,x_16

watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA5a2d5q2j5ZCM5a2m,size_16,color_FFFFFF,t_70,g_se,x_16

和 SGD (随机梯度下降算法)一样,ffd71c928e974fa6880a6770becd112c.png  表示要更新的权重参数,58032a0447844d8c98a8a8a5b225a0d7.png 表示损失函数关于 774086325de045ee8148212ed4114007.png  的梯度,η 表示学习率。这里新出现了变量 b32c4bc936874bf993f4da060e62f079.png  ,如式 (6.5) 所示,它保存了以前的所有梯度值的平方和(式(6.5)中的 a163de5dd642444e9e3ca86fc3715f46.png  表示对应矩阵元素的乘法)。然后,在更新参数时,通过乘以 d5bb729c08d843d6a41d7a0faf1d95e5.png  ,就可以调整学习的尺度。这意味着,参数的元素中变动较大(被大幅更新)的元素的学习率将变小。也就是说,可以按参数的元素进行学习率衰减,使变动大的参数的学习率逐渐减小。

AdaGrad 会记录过去所有梯度的平方和。因此,学习越深入,更新的幅度就越小。实际上,如果无止境地学习,更新量就会变为 0,完全不再更新。

为了改善这个问题,可以使用 RMSProp 方法。RMSProp 方法并不是将过去所有的梯度一视同仁地相加,而是逐渐地遗忘过去的梯度,在做加法运算时将新梯度的信息更多地反映出来。这种操作从专业上讲,称为“指数移动平均”,呈指数函数式地减小过去的梯度的尺度。

现在来实现 AdaGrad。AdaGrad 的实现过程如下所示。 

class AdaGrad:
    def __init__(self, lr=0.01):
        self.lr = lr#学习率
        self.h = None
    def update(self, params, grads):
        if self.h is None:
        	self.h = {}
        	for key, val in params.items():
            		self.h[key] = np.zeros_like(val)
    	for key in params.keys():
        	self.h[key] += grads[key] * grads[key]
        	params[key] -= self.lr * grads[key] / (np.sqrt(self.h[key]) + 1e-7)

这里需要注意的是,最后一行加上了微小值 1e-7 。这是为了防止当 self.h[key] 中有 0 时,将 0 用作除数的情况。在很多深度学习的框架中,这个微小值也可以设定为参数,但这里我们用的是 1e-7 这个固定值。

现在,让我们试着使用 AdaGrad 解决式(6.2)的最优化问题,结果如图 6-6 所示。

深度学习入门之AdaGrad_第1张图片

图 6-6 基于 AdaGrad 的最优化的更新路径

由图 6-6 的结果可知,函数的取值高效地向着最小值移动。由于 y 轴方向上的梯度较大,因此刚开始变动较大,但是后面会根据这个较大的变动按比例进行调整,减小更新的步伐。因此,y 轴方向上的更新程度被减弱,“之”字形的变动程度有所衰减。

你可能感兴趣的:(深度学习入门,深度学习,cnn,p2p)