使用model.module保存模型参数

问题再现

之前训练的很好的model(mAP=80)保存之后,在另一个文件里加载,结果效果很差劲(mAP=3);

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    if is_best:
        torch.save(state, os.path.split(filename)[0] + '/model_best.pth.tar')
    else:
        torch.save(state, os.path.split(filename)[0] + filename)

if mAP_ema > mAP:
    mAP = mAP_ema
    state_dict = ema_m.module.state_dict()
else:
    state_dict = model.state_dict()

save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': state_dict,
            'best_mAP': best_mAP,
            'optimizer' : optimizer.state_dict(),
        }, is_best=is_best, filename=os.path.join(args.output, 'checkpoint.pth.tar'))

用 model.module 替代单独的 model

if mAP_ema > mAP:
    mAP = mAP_ema
    state_dict = ema_m.module.state_dict()
else:
    state_dict = model.module.state_dict()

保存模型,重新在另外一个模型加载,跑一遍validate(),最后结果也很棒;
所以这个方法是有效的;

参考

[1] Pytorch加载保存好的模型发现与实际保存模型的参数不一致

你可能感兴趣的:(#,读论文,写代码,pytorch)