loss收敛过小或finetune时跑飞情况分析

   最近在跑REDNet去噪网络,基于pytorch框架,输入输出图像被归一化到了0-1之间,loss使用的是L2 loss,理论最终收敛值为33左右。结果在之前训练好模型的基础上finetune,出现了这种情况:


loss收敛过小或finetune时跑飞情况分析_第1张图片

   当使用L1或者输入为0-255时,并没有出现这种情况。这里使用L2 loss,输入0-1之间,loss先下降,再上升,这是很反常的,一般来说有以下问题导致:
1.学习率设置不正确。
2.训练数据出现了偏差。
3.梯度更新出现问题,爆炸或者消失。

   这里后来经过分析,很可能是学习率设置导致,当然梯度更新不正确也有影响,学习率设置不正确分析如下:


loss收敛过小或finetune时跑飞情况分析_第2张图片

   也就是说,学习率如果设置过大,会导致其直接跑到另外一边,从而导致loss跑飞。当然真实的loss变化是在一个奇异空间里的,不是图上的那种二维曲线。
   但是最后还发现了一个现象,那就是最后的网络的收敛值只有30出头,而不是理论的33.多,也就是说网络学习不充分!如下图所示:


loss收敛过小或finetune时跑飞情况分析_第3张图片

   网络学习不充分,前面finetune的loss还跑飞了一段,很可能网络的梯度更新出现了问题,需要打印网络的每层梯度进行查看,代码如下所示:

self.optimizer.step()
for params in self.net.named_parameters():
    [name,param] = params
    print(name,':',param.grad.data.mean())
    print(count)

   迭代2000次,打印出来梯度值就是下图左图这个样子,右图是数据归一化到0-255之间时各层的梯度值:


loss收敛过小或finetune时跑飞情况分析_第4张图片

   上图的现象说明了当数据归一化到0-1之间时,该网络的中间部分几乎没有学习,参数缺少更新,为了验证这个想法,我分别输出迭代2000次时0-1输入时该网络的残差的均值和输入0-255时网络的残差的均值,分别为2.5x10^-3和0.26,差了两个数量级,验证了猜想,即输入为0-1时网络基本没有学到东西。
   由于网络的残差是在deconv10层,直接获得该层输出的代码(不通过参数返回的方式)如下所示:

activation = {}
def get_activation(name):
def hook(model, input, output):
    activation[name] = output.detach()
    return hook
model = self.net
model.deconv10.register_forward_hook(get_activation('deconv10'))
output = model(lr_imgs)
print('res:', activation['deconv10'].mean())

   现在通过结论寻找原因,当输出为output,网络权重为w,损失函数为L2 loss,目标为target,分析loss对权重更新也就是梯度的影响,如下式所示:

(outputtarget)2w=(outputtarget)2(output)(output)w=a(output)w ∂ ( o u t p u t − t a r g e t ) 2 ∂ w = ∂ ( o u t p u t − t a r g e t ) 2 ∂ ( o u t p u t ) ∂ ( o u t p u t ) ∂ w = a ∗ ∂ ( o u t p u t ) ∂ w

其中a是和output-target一个数量级的数值,由于output和target都是在0-1之间,a的值只能更小了,应该远远小于1。如果loss是L1 loss,那么a就是1,在 (output)w ∂ ( o u t p u t ) ∂ w 相同的情况下,L1 loss对梯度大小的贡献要大于L2 loss对梯度大小的贡献,更有利于中间部分get到梯度,输入为0-255时可同理分析得到,损失函数为L2时,a应该是一个比1远远大的值,损失函数为L1时,a会略微比1大。这样来看,对于梯度大小的贡献排序是L2(0-255)>L1(0-255)>L1(0-1)>L2(0-1)。
   我比较了一下输入为0-1时L2,L1 loss和输入为0-255时L2 loss和L1 loss的网络收敛的PSNR值,对比如下表,可以看出来上面的分析基本正确:

输入与Loss PSNR
L2(0-1) 30.4
L1(0-1) 33.24
L1(0-255) 33.97
L2(0-255) 34.05

   同理,因为初始学习率设置不合理,loss向上震荡后,输入为0-1,loss为L2 loss梯度回传受阻,所以此时finetune会导致loss无法回来(梯度回传受阻,参数无法有效更新),只能不断向上,从而出现跑飞的情况。
   上面只是我简短的对这种情况下loss训练异常的分析,如有不足请大家多多指正,多多交流。

你可能感兴趣的:(学习总结)