动手学深度学习:7.6 RMSProp算法

7.6 RMSProp算法

我们在7.5节(AdaGrad算法)中提到,因为调整学习率时分母上的变量stst \boldsymbol{s}_tf(x)=0.1x12+2x22中自变量的迭代轨迹。回忆在7.5节(AdaGrad算法)使用的学习率为0.4的AdaGrad算法,自变量在迭代后期的移动幅度较小。但在同样的学习率下,RMSProp算法可以更快逼近最优解。

%matplotlib inline
import math
import torch
import sys
sys.path.append("..") 
import d2lzh_pytorch as d2l

def rmsprop_2d(x1, x2, s1, s2):
g1, g2, eps = 0.2 x1, 4 x2, 1e-6
s1 = gamma s1 + (1 - gamma) g1 2
s2 = gamma s2 + (1 - gamma) g2 2
x1 -= eta / math.sqrt(s1 + eps) g1
x2 -= eta / math.sqrt(s2 + eps) g2
return x1, x2, s1, s2

def f_2d(x1, x2):
return 0.1 x1 ** 2 + 2 x2 ** 2

eta, gamma = 0.4, 0.9
d2l.show_trace_2d(f_2d, d2l.train_2d(rmsprop_2d))Copy to clipboardErrorCopied

输出:

epoch 20, x1 -0.010599, x2 0.000000Copy to clipboardErrorCopied

7.6.2 从零开始实现

接下来按照RMSProp算法中的公式实现该算法。

features, labels = d2l.get_data_ch7()

def init_rmsprop_states():
s_w = torch.zeros((features.shape[1], 1), dtype=torch.float32)
s_b = torch.zeros(1, dtype=torch.float32)
return (s_w, s_b)

def rmsprop(params, states, hyperparams):
gamma, eps = hyperparams[‘gamma’], 1e-6
for p, s in zip(params, states):
s.data = gamma s.data + (1 - gamma) (p.grad.data)**2
p.data -= hyperparams[‘lr’] * p.grad.data / torch.sqrt(s + eps)Copy to clipboardErrorCopied

我们将初始学习率设为0.01,并将超参数γγ \gammagtgt的加权平均。

d2l.train_ch7(rmsprop, init_rmsprop_states(), {‘lr’: 0.01, ‘gamma’: 0.9},
features, labels)
Copy to clipboardErrorCopied

输出:

loss: 0.243452, 0.049984 sec per epochCopy to clipboardErrorCopied

7.6.3 简洁实现

通过名称为RMSprop的优化器方法,我们便可使用PyTorch提供的RMSProp算法来训练模型。注意,超参数γγ \gammaγ通过alpha指定。

d2l.train_pytorch_ch7(torch.optim.RMSprop, {'lr': 0.01, 'alpha': 0.9},
                    features, labels)Copy to clipboardErrorCopied

输出:

loss: 0.243676, 0.043637 sec per epochCopy to clipboardErrorCopied

小结

  • RMSProp算法和AdaGrad算法的不同在于,RMSProp算法使用了小批量随机梯度按元素平方的指数加权移动平均来调整学习率。

参考文献

[1] Tieleman, T., & Hinton, G. (2012). Lecture 6.5-rmsprop: Divide the gradient by a running average of its recent magnitude. COURSERA: Neural networks for machine learning, 4(2), 26-31.


注:除代码外本节与原书此节基本相同,原书传送门

你可能感兴趣的:(#,深度学习)