gradnorm论文地址:https://arxiv.org/abs/1711.02257
gradnorm是一种优化方法,在多任务学习(Multi-Task Learning)中,解决 1. 不同任务loss梯度的量级(magnitude)不同,造成有的task在梯度反向传播中占主导地位,模型过分学习该任务而忽视其它任务;2. 不同任务收敛速度不一致;这两个问题。
从实现上来看,gradnorm除了利用label loss更新神经网络的参数外,还会使用grad loss更新每个任务(task)的损失(loss)在总损失中的权重 w w w。
以简单的多任务学习模型shared bottom为例,两个任务的shared bottom结构如下,输出的两个tower分别拟合两个任务。
针对这样的模型,最简单的方法就是每个任务单独计算损失,然后汇总起来,最终的损失函数如下:
l o s s ( t ) = l o s s A ( t ) + l o s s B ( t ) loss(t) = loss_{A}(t)+loss_{B}(t) loss(t)=lossA(t)+lossB(t)
但是,两个任务的loss反向传播的梯度量级可能不同,在反向传播到shared bottom部分时,梯度量级小的任务对模型参数更新的比重少,使得shared bottom对该任务的学习不充分。因此,我们可以简单的引入权重,平衡梯度,如下:
l o s s ( t ) = w A × l o s s A ( t ) + w B × l o s s B ( t ) loss(t) =w_{A}\times loss_{A}(t)+w_{B}\times loss_{B}(t) loss(t)=wA×lossA(t)+wB×lossB(t)
这样做并没有很好的解决问题,首先,如果loss权重 w w w在训练过程中为定值,最初梯度量级大的任务,我们给一个小的 w w w,到训练结束,这个小的 w w w会一直限制这一任务,使得这一任务不能得到很好的学习。因此,需要梯度也是不断变化的,更新公式如下:
l o s s ( t ) = w A ( t ) × l o s s A ( t ) + w B ( t ) × l o s s B ( t ) loss(t) =w_{A}(t)\times loss_{A}(t)+w_{B}(t)\times loss_{B}(t) loss(t)=wA(t)×lossA(t)+wB(t)×lossB(t)
gradnorm就是用梯度,来动态调整loss的 w w w的优化方法。
想要动态更新loss的 w w w,最直观的方法就是利用grad,因为在多任务学习中,我们解决的就是多任务梯度不平衡的问题,如果我们能知道 w w w的更新梯度(这里的梯度不是神经网络参数的梯度,是loss权重 w w w的梯度),就可以利用梯度更新公式,来动态更新 w w w,就像更新神经网络的参数一样,如下,其中 λ \lambda λ沿用全局的神经网络学习率。
w ( t + 1 ) = w ( t ) + λ β ( t ) w(t+1) = w(t)+\lambda\beta (t) w(t+1)=w(t)+λβ(t)
我们的目的是平衡梯度,所以 β \beta β最好是梯度关于 w w w的导数,为此定义梯度损失如下:
G r a d L o s s = Σ i ∣ G W i ( t ) − G ‾ W ( t ) × [ r i ( t ) ] α ∣ Grad~Loss = \Sigma_{i}\Big|G_W^{i}(t)-\overline{G}_{W}(t)\times [r_i(t)]^{\alpha}\Big| Grad Loss=Σi∣∣∣GWi(t)−GW(t)×[ri(t)]α∣∣∣
G W i ( t ) = ∣ ∣ ▽ W w i ( t ) L i ( t ) ∣ ∣ 2 G_W^{i}(t)=||\bigtriangledown_Ww_i(t)L_i(t)||_2 GWi(t)=∣∣▽Wwi(t)Li(t)∣∣2
G ‾ W ( t ) = E t a s k [ G W i ( t ) ] \overline{G}_W(t)=E_{task}[G_W^i(t)] GW(t)=Etask[GWi(t)]
r i ( t ) = L ~ i ( t ) E t a s k [ L ~ i ( t ) ] r_i(t)=\frac{\widetilde{L}_{i}(t)}{E_{task}[\widetilde{L}_{i}(t)]} ri(t)=Etask[L i(t)]L i(t)
L ~ i ( t ) = L i ( t ) L 0 ( t ) \widetilde{L}_{i}(t)=\frac{L_{i}(t)}{L_{0}(t)} L i(t)=L0(t)Li(t)
这几个公式就是论文最核心的部分,其中, G r a d L o s s Grad~Loss Grad Loss定义为,各个任务实际的梯度范数与理想的梯度范数的差的绝对值和; G W i ( t ) G_W^{i}(t) GWi(t)为实际的梯度范数, G ‾ W ( t ) × [ r i ( t ) ] α \overline{G}_{W}(t)\times [r_i(t)]^{\alpha} GW(t)×[ri(t)]α为理想的梯度范数; G W i ( t ) G_W^{i}(t) GWi(t)是任务 i i i的带权损失 w i ( t ) L i ( t ) w_i(t)L_i(t) wi(t)Li(t),对需要更新的神经网络参数 W W W( W W W表示神经网络参数, w w w表示loss权重)的梯度的L2范数; G ‾ W ( t ) \overline{G}_W(t) GW(t)是对所有任务求得的 G W i ( t ) G_W^{i}(t) GWi(t)的平均; L ~ i ( t ) \widetilde{L}_{i}(t) L i(t)表示任务 i i i的反向训练速度, L ~ i ( t ) \widetilde{L}_{i}(t) L i(t)越大, L i ( t ) L_{i}(t) Li(t)越大,任务 i i i训练越慢; r i ( t ) r_i(t) ri(t)是任务 i i i的相对反向训练速度。
α \alpha α是超参数, α \alpha α越大,对训练速度的平衡限制越强。为了节约计算时间, G r a d L o s s Grad~Loss Grad Loss仅对shared bottom的输出部分计算。
有了 G r a d L o s s Grad~Loss Grad Loss,就可以利用 G r a d L o s s Grad~Loss Grad Loss对 w i ( t ) w_i(t) wi(t)求导,得到上面梯度更新公式中需要的 β ( t ) \beta(t) β(t)。为了防止 w i ( t ) w_i(t) wi(t)变为0,在对 G r a d L o s s Grad~Loss Grad Loss求导时,认为 G ‾ W ( t ) × [ r i ( t ) ] α \overline{G}_{W}(t)\times [r_i(t)]^{\alpha} GW(t)×[ri(t)]α部分为常数,即使其中有 w i ( t ) w_i(t) wi(t)。在每一个batch step的最后,为了节藕gradnorm过程中,利用 G r a d L o s s Grad~Loss Grad Loss对 w i ( t ) w_i(t) wi(t)求导过程与全局训练神经网络的学习率的关系,会对 w i ( t ) w_i(t) wi(t)在进行 Σ i w i ( t ) = T \Sigma_{i}w_i(t)=T Σiwi(t)=T的renormalize, T T T是任务总数。
gradnorm示意如下:
gradnorm在单个batch step的流程总结如下:
1.前向传播计算总损失 L o s s = Σ i w i l i Loss=\Sigma_iw_il_i Loss=Σiwili;
2.计算 G W i ( t ) G_W^{i}(t) GWi(t), r i ( t ) r_i(t) ri(t), G ‾ W i ( t ) \overline{G}_W^{i}(t) GWi(t);
3.计算 G r a d L o s s Grad~Loss Grad Loss;
4.计算 G r a d L o s s Grad~Loss Grad Loss对 w i w_i wi的导数;
5.利用第1步计算的的 L o s s Loss Loss反向传播更新神经网络参数;
6.利用第4步的导数更新 w i w_i wi(更新后在下一个batch step生效);
7.对 w i w_i wi进行renormalize(下一个batch step使用的是renormalize之后的 w i w_i wi)。
附上论文原版步骤:
参考文献:
https://github.com/brianlan/pytorch-grad-norm