pytorch优化器详解:Adam

目录

说明

Adam原理

梯度滑动平均

偏差纠正

Adam计算过程

pytorch Adam参数

params

lr

betas

eps

weight_decay

amsgrad


说明

模型每次反向传导都会给各个可学习参数p计算出一个偏导数g_t,用于更新对应的参数p。通常偏导数g_t不会直接作用到对应的可学习参数p上,而是通过优化器做一下处理,得到一个新的\widehat{g}_t,处理过程用函数F表示(不同的优化器对应的F的内容不同),即\widehat{g}_t=F(g_t),然后和学习率lr一起用于更新可学习参数p,即p=p-\widehat{g}_t*lr

Adam是在RMSProp和AdaGrad的基础上改进的。先掌握RMSProp的原理,就很容易明白Adam了。本文是在RMSProp这篇博客的基础上写的。

Adam原理

在RMSProp的基础上,做两个改进:梯度滑动平均偏差纠正

梯度滑动平均

在RMSProp中,梯度的平方是通过平滑常数平滑得到的,即v_t=\beta*v_{t-1}+(1-\beta)*(g_t)^2根据论文,梯度平方的滑动均值用v表示;根据pytorch源码,Adam中平滑常数用的是β,RMSProp中用的是α),但是并没有对梯度本身做平滑处理。

在Adam中,对梯度也做了平滑,平滑后的滑动均值用m表示,即m_t=\beta*m_{t-1}+(1-\beta)*g_t,在Adam中有两个β。

偏差纠正

上述m的滑动均值的计算,当t=1时,m_1=\beta*m_0+(1-\beta)*g_1,由于m_0的初始是0,且β接近1,因此t较小时,m的值是偏向于0的,v也是一样。这里通过除以1-\beta^t来进行偏差纠正,即\widehat{m}_t=\frac{m_t}{1-\beta^t}

Adam计算过程

为方便理解,以下伪代码和论文略有差异,其中蓝色部分是比RMSProp多出来的。

  1. 初始:学习率 lr
  2. 初始:平滑常数(或者叫做衰减速率) \beta_1,\beta_2,分别用于平滑m和v
  3. 初始:可学习参数 \theta_0
  4. 初始:m_0=0,v_0=0,t=0
  5. while 没有停止训练 do
  6.         训练次数更新:t=t+1
  7.         计算梯度:g_t(所有的可学习参数都有自己的梯度,因此 g_t表示的是全部梯度的集合)
  8.         累计梯度:{\color{Blue} m_t=\beta_1*m_{t-1}+(1-\beta_1)*g_t}(每个导数对应一个m,因此m也是个集合)
  9.         累计梯度的平方:v_t=\beta_2*v_{t-1}+(1-\beta_2)*(g_t)^2(每个导数对应一个v,因此v也是个集合)
  10.         偏差纠正m:{\color{Blue} \widehat{m}_t=\frac{m_t}{1-(\beta_1)^t}}
  11.         偏差纠正v:{\color{Blue} \widehat{v}_t=\frac{v_t}{1-(\beta_2)^t}}
  12.         更新参数:\theta_t=\theta_{t-1}-\frac{\widehat{m}_t}{\sqrt{\widehat{v}_t}+\epsilon}lr
  13. end while

pytorch Adam参数

torch.optim.Adam(params,
                lr=0.001,
                betas=(0.9, 0.999),
                eps=1e-08,
                weight_decay=0,
                amsgrad=False)

params

模型里需要被更新的可学习参数

lr

学习率

betas

平滑常数\beta_1\beta_2

eps

\epsilon,加在分母上防止除0

weight_decay

weight_decay的作用是用当前可学习参数p的值修改偏导数,即:g_t=g_t+(p*weight\_decay),这里待更新的可学习参数p的偏导数就是g_t

weight_decay的作用是L2正则化,和Adam并无直接关系。

amsgrad

如果amsgrad为True,则在上述伪代码中的基础上,保留历史最大的v_t,记为v_{max},每次计算都是用最大的v_{max},否则是用当前v_t

amsgrad和Adam并无直接关系。

你可能感兴趣的:(pytorch,深度学习,Adam算法,深度学习,adam算法)