python工具方法35 实现SWA,再一次提升模型的性能

SWA是论文Averaging Weights Leads to Wider Optima and Better Generalization所提出的一种无痛涨点的方式,只需要在模型训练的最后阶段保存模型权重,然后取模型权重的平均值,就可以提升模型的权重。按照论文描述,针对不同的模型基本上可以涨一个点。论文:SWA Object Detection详细描述了实验SWA后,模型的涨点效果。

SWA的论文翻译:https://github.com/timgaripov/swa
SWA的项目地址:https://github.com/timgaripov/swa

为此,博主根据论文描述和SWA作者公布的源码,仿照ema的模型增强技术代码,重新实现了swa。这里的实现支持torch、paddle(博主亲测,tf2模型也应该是支持的,只是要修改权重加载与保存的部分)。这里的实现是针对模型权重,对于pytorch的mmdetection框架,paddle的paddledetection框架中的模型都是支持的。博主亲测,用swa提升了0.5的map。

按照swa论文所述,当模型带bn层时,swa_model中的bn层参数需要重新更新。因此,博主刻意实现了一个forward函数,用于更新bn层的参数【针对mmdetection、paddledetection等框架时无效】。

1、SWA实现

博主这里实现的SWA支持在训练过程中使用,也支持在模型训练完成后选择模型进行权重平均。
针对于用户只需要关注两个函数update和smooth_dir。update用于在训练过程中调用(在合适的epoch中[epoch数大于budget时]进行权重平均),smooth_dir用于在模型训练

你可能感兴趣的:(python工具方法)