On saving and loading model weights

        checkpoint = load_checkpoint(args.resume)
        model_dict = model.state_dict()
        checkpoint_load = {k: v for k, v in checkpoint['state_dict'].items()
                           if k in model_dict}
        model_dict.update(checkpoint_load)
        model.load_state_dict(model_dict)
        start_epoch = checkpoint['epoch']
        best_top1 = checkpoint['best_top1']
        print("=> Start epoch {}  best top1 {:.1%}".format(start_epoch, best_top1))

1. First, load the saved weights from args.resume into checkpoint, which behaves like a dict.

2. Read the current model's parameter weights into model_dict.

3. Copy into checkpoint_load only those entries of checkpoint['state_dict'] whose keys also exist in model_dict.

4. Update model_dict with these trained parameters and load the result back into the model (single-GPU).

5. Then read the remaining fields from checkpoint (epoch, best_top1, and so on) in the same way.
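The filtering-and-update logic of steps 3 and 4 can be sketched without torch at all; below, plain dicts stand in for the real state_dicts, and the key names are made up for illustration:

```python
# Stand-in for checkpoint['state_dict']: saved weights, including a layer
# that no longer exists in the current model (hypothetical key names).
checkpoint_state = {
    'backbone.weight': 1.0,
    'old_head.weight': 2.0,   # stale layer, will be filtered out
}
# Stand-in for model.state_dict(): the current model's parameters.
model_dict = {
    'backbone.weight': 0.0,
    'new_head.weight': 0.0,   # freshly initialized layer, not in checkpoint
}

# Step 3: keep only the keys the current model also has
checkpoint_load = {k: v for k, v in checkpoint_state.items() if k in model_dict}

# Step 4: overwrite matching entries; new layers keep their init values
model_dict.update(checkpoint_load)

print(model_dict)  # {'backbone.weight': 1.0, 'new_head.weight': 0.0}
```

This is why the pattern is useful for fine-tuning: keys missing from either side are simply skipped rather than raising the key-mismatch error a strict load_state_dict would produce.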

model.module.load_state_dict(checkpoint['state_dict'])

Loading model parameters (multi-GPU): a model wrapped in nn.DataParallel exposes the underlying network as model.module, so the state dict is loaded through it.
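Because nn.DataParallel prefixes every parameter name with "module.", a checkpoint saved from a wrapped model does not load directly into a bare model (and vice versa). A common fix is to rename the keys; here is a minimal sketch using plain dicts as stand-ins for real state_dicts:

```python
def strip_module_prefix(state_dict):
    """Remove the 'module.' prefix that nn.DataParallel adds to parameter names."""
    return {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in state_dict.items()}

def add_module_prefix(state_dict):
    """Add the 'module.' prefix so a bare checkpoint loads into a wrapped model."""
    return {'module.' + k: v for k, v in state_dict.items()}

# Hypothetical multi-GPU checkpoint keys
multi_gpu_state = {'module.conv1.weight': 1.0, 'module.fc.bias': 2.0}

single_gpu_state = strip_module_prefix(multi_gpu_state)
print(single_gpu_state)  # {'conv1.weight': 1.0, 'fc.bias': 2.0}

wrapped_state = add_module_prefix(single_gpu_state)
print(wrapped_state)     # {'module.conv1.weight': 1.0, 'module.fc.bias': 2.0}
```

Saving model.module.state_dict() in the first place, as the snippets below do, avoids the problem entirely: the stored keys are already unprefixed.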

torch.save(model.state_dict(), model_out_path)

Saving model parameters (single-GPU)

# inside save_checkpoint: torch.save(state, fpath)
save_checkpoint({
    'state_dict': model.module.state_dict(),
    'epoch': epoch + 1,
    'best_top1': best_top1
})

Saving model parameters (multi-GPU) along with other information such as the epoch and best accuracy; save_checkpoint is a thin wrapper around torch.save(state, fpath).
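Conceptually, a checkpoint is just a serialized dict. The sketch below illustrates the save/load roundtrip with the stdlib pickle module as a stand-in; real code would call torch.save / torch.load, and the helper names and field values here are hypothetical:

```python
import os
import pickle
import tempfile

def save_checkpoint(state, fpath):
    # Stand-in for torch.save(state, fpath): serialize the whole dict
    with open(fpath, 'wb') as f:
        pickle.dump(state, f)

def load_checkpoint(fpath):
    # Stand-in for torch.load(fpath): recover the dict as saved
    with open(fpath, 'rb') as f:
        return pickle.load(f)

fpath = os.path.join(tempfile.mkdtemp(), 'checkpoint.pkl')
save_checkpoint({
    'state_dict': {'fc.weight': 0.5},   # stand-in for model.module.state_dict()
    'epoch': 5,
    'best_top1': 0.87,
}, fpath)

ckpt = load_checkpoint(fpath)
print(ckpt['epoch'], ckpt['best_top1'])  # 5 0.87
```

Bundling epoch and best_top1 next to the weights is what makes the resume logic at the top of this post possible: one file carries everything needed to continue training.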
