pytorch : Stochastic Weight Averaging理解和用法

SWA has been proposed in Averaging Weights Leads to Wider Optima and Better Generalization.

SGD倾向于收敛到loss的平坦的区域,由于权重空间的维度比较高,平坦区域的大部分都处于边界,SGD通常只会走到这些平稳区域的边界。SWA通过平均多个SGD的权重参数,使其能够达到平坦区域的中心,从而得到更优的解。就相当于在一个最优解的附近,梯度都很小了,迭代可能一之在最优解附近震荡,通过平均最优解周围的参数,可以更接近最优解,利用平均值中和误差;

第一种实现方法:就是最后几次的平均

swa_model = AveragedModel(model)
swa_model.update_parameters(model) # 最后来一次就行,平均一下就可

第二种:一边平均,一边学习率变化

>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, \
>>>         anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)

注意事项:

权重更新完了之后,需要来一下这个,对batchnormalization

torch.optim.swa_utils.update_bn(loader, swa_model

BN层训练过程中计算激活神经元的统计信息,而SWA平均的权重在训练过程中是不会用来预测的,所以当权重更新之后,BN层相对应的统计信息仍然是之前权重的。为了计算激活值的统计信息,只需要在训练结束之后对训练数据前向传播一次即可。

自定义平均:

默认情况下,torch.optim.swa_utils.AveragedModel计算参数的运行平均值,但也可以将自定义平均值函数与avg_fn参数一起使用。在以下示例中,ema_模型计算指数移动平均数。

>>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
>>>         0.1 * averaged_model_parameter + 0.9 * model_parameter
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)

参考文章:

1、https://liumin.blog.csdn.net/article/details/113128343 (随机梯度平均)

2、https://zhuanlan.zhihu.com/p/122504469

你可能感兴趣的:(pytorch,深度学习,计算机视觉)