pytorch加载预训练权重

resume_path 为 checkpoint.pth 的文件路径

checkpoint = torch.load(resume_path, map_location=torch.device(‘cpu’))

修改训练开始轮数

args.start_epoch = checkpoint['epoch']

获取预训练权重的参数

new_param = checkpoint['state_dict']

加载模型参数

model.load_state_dict(new_param)

加载优化器

optimizer.load_state_dict(checkpoint['optimizer'])

代码来源于DCP模型(cvpr22)的源代码

################### args.resume为checkpoint.pth文件 ###################
if args.resume:
    resume_path = osp.join(args.snapshot_path, args.resume)
    if os.path.isfile(resume_path):
        if main_process():
            logger.info("=> loading checkpoint '{}'".format(resume_path))
        ################### 加载预训练权重 ###################
        checkpoint = torch.load(resume_path, map_location=torch.device('cpu'))
        ################### 修改训练开始轮数 ###################
        args.start_epoch = checkpoint['epoch']
        ################### 获取预训练权重的参数 ###################
        new_param = checkpoint['state_dict']
        try: 
        	##################### 加载模型参数 ###################
            model.load_state_dict(new_param)
        except RuntimeError:                   # 1GPU loads mGPU model
            for key in list(new_param.keys()):
                new_param[key[7:]] = new_param.pop(key)
            model.load_state_dict(new_param)
        ##################### 加载优化器###################
        optimizer.load_state_dict(checkpoint['optimizer'])
        if main_process():
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(resume_path, checkpoint['epoch']))
    else:
        if main_process():       
            logger.info("=> no checkpoint found at '{}'".format(resume_path))

你可能感兴趣的:(pytorch,深度学习)