Machine Learning Tricks: Stochastic Weight Averaging (SWA)

SWA: Stochastic Weight Averaging
The local minima reached at the end of each learning-rate cycle tend to lie near the border of a low-loss region of the loss surface.
By averaging several such points, we are likely to obtain a more general solution with even lower loss.
This trick adds no inference latency: at test time only the averaged weights are used.
https://arxiv.org/pdf/1803.05407.pdf
[Figure: illustrations of SWA from the paper]
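Concretely, SWA's final weights are just the element-wise mean of the cycle-end checkpoints, w_SWA = (1/n) Σᵢ wᵢ. A minimal sketch with toy weight vectors (the numbers below are made up for illustration, not real checkpoints):

```python
# SWA in miniature: element-wise mean of weights from several cycle-end checkpoints.
checkpoints = [
    [0.9, -1.2, 0.4],   # weights at the end of cycle 1 (toy values)
    [1.1, -0.8, 0.6],   # weights at the end of cycle 2
    [1.0, -1.0, 0.5],   # weights at the end of cycle 3
]
n = len(checkpoints)
swa_weights = [sum(w[i] for w in checkpoints) / n
               for i in range(len(checkpoints[0]))]
print(swa_weights)  # element-wise mean across checkpoints
```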

In PyTorch, the lr scheduler can be set up as follows:
optimizer = optim.SGD(net.parameters(), lr=initial_lr, momentum=momentum, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=3, T_mult=1, eta_min=1e-5, last_epoch=-1)
lr = optimizer.param_groups[-1]['lr']
scheduler.step(epoch + (iteration + 1) / max_iter)  # pass a fractional epoch, per the PyTorch docs
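The schedule behind CosineAnnealingWarmRestarts is η_t = η_min + ½(η_max − η_min)(1 + cos(π · T_cur / T_i)), where T_cur resets to 0 every T_i epochs. A pure-Python sketch of that formula, using the same T_0=3, T_mult=1, eta_min=1e-5 as above (initial_lr=0.1 is an assumed value, not from the original snippet):

```python
import math

def cosine_warm_restart_lr(epoch, initial_lr=0.1, T_0=3, T_mult=1, eta_min=1e-5):
    """Learning rate at a (possibly fractional) epoch, with warm restarts."""
    # Walk forward through cycles to find the current cycle length t_i
    # and the position t_cur within it.
    t_i, t_cur = T_0, epoch
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= T_mult
    return eta_min + (initial_lr - eta_min) * (1 + math.cos(math.pi * t_cur / t_i)) / 2

for e in range(7):
    print(e, round(cosine_warm_restart_lr(e), 6))
```

With T_mult=1 the lr jumps back to initial_lr every 3 epochs, which is why the checkpoints taken just before each restart (where lr ≈ eta_min) are the ones worth averaging.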

import os
import torch

def main():
    model_dir = 'model/path'
    save_dir = 'swa_model/resnet18'
    os.makedirs(save_dir, exist_ok=True)
    # With a 12-epoch run, take the epochs at the end of each cosine cycle (lr ≈ eta_min)
    model_epochs = [2, 5, 8, 11]
    model_paths = [
        os.path.join(model_dir, 'checkpoint-epoch' + str(i) + '.pth')
        for i in model_epochs
    ]
    print('model_paths', model_paths)
    models = [torch.load(path, map_location='cpu') for path in model_paths]
    model_num = len(models)
    ref_model = models[-1]
    new_state_dict = ref_model['state_dict'].copy()

    for key in new_state_dict:
        avg_weight = sum(m['state_dict'][key].float() for m in models) / model_num
        # Cast back so integer buffers (e.g. BatchNorm num_batches_tracked) keep their dtype
        new_state_dict[key] = avg_weight.to(new_state_dict[key].dtype)
    # Note: the SWA paper recommends recomputing BatchNorm running statistics
    # with a forward pass over the training data after averaging.
    ref_model['state_dict'] = new_state_dict
    save_path = os.path.join(save_dir, 'checkpoint-best.pth')
    torch.save(ref_model, save_path)
    print('Model is saved at', save_path)

if __name__ == '__main__':
    main()
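The key-wise averaging loop in the script can be sanity-checked with plain dictionaries standing in for each checkpoint's state_dict (the keys and values below are hypothetical):

```python
# Mirror of the script's averaging loop, with plain floats instead of tensors.
state_dicts = [
    {'conv.weight': 1.0, 'fc.bias': -2.0},  # "checkpoint" at epoch 2 (toy values)
    {'conv.weight': 3.0, 'fc.bias':  0.0},  # "checkpoint" at epoch 5
]
n = len(state_dicts)
avg_state_dict = {key: sum(sd[key] for sd in state_dicts) / n
                  for key in state_dicts[0]}
print(avg_state_dict)  # each entry is the mean across checkpoints
```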
    
    
