Pytorch 中的 torch.optim.swa_utils.AverageModel() 及其原理总结

1 背景知识

在了解 torch.optim.swa_utils.AverageModel() 前, 我们先了解以下 SWA(随机加权平均)

1.1 SWA

SWA 全称 : Stochastic Weight Averaging,

  • SWA是使用修正后的学习率策略对SGD(或任何随机优化器)遍历的权重进行平均,从而可以得到更好的收敛效果

  • 随机梯度下降(SGD)在测试集上,趋向于收敛至损失相对低的地方,但却很难收敛至最低点, 经过几个epoch的训练,得到了W1,W2,W3三个权重,但无法收敛至最低点。如果使用SWA可以将三个权重加权平均,从而可能收敛至相对SGD更小的损失

  • SGD在训练集收敛得比较好,但是在测试集效果并不如SWA。而SWA虽然在训练集收敛得不如SGD,但是在测试集上表现得更加好

2 AverageModel() 介绍

AveragedModel 类用于计算SWA模型的权重。可以通过运行以下命令创建一个averaged model:

from torch.optim.swa_utils import AverageModel
swa_model = AverageModel(model)

这里的模型Model可以是任意的torch.nn.Module对象。swa_model将跟踪模型参数的运行平均值。要更新这些平均值,你可以使用update_parameters()函数:

swa_model.update_parameters(model)

你可能感兴趣的:(Pytorch,中的各种函数,Pytorch)