pytorch如何加载部分模型参数

使用openpose pytorch版本查看中间热力图结果,需要加载部分参数,过程如下

1.把模型的结构加载进来 

pretrained_dict = torch.load(model_body25)
model = bodypose_25_model()

2.通过字典形式,加载网络中的部分参数 

model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

 3.转移cuda,改成eval模式,如果模型中有relu或BN层,切记一定要加eval(),否则很有可能每次预测的结果都不一样或达不到预期

# if torch.cuda.is_available():
#     model = model.cuda()
model = model.cuda()
model.eval()

4.最终效果

 pytorch如何加载部分模型参数_第1张图片

pytorch如何加载部分模型参数_第2张图片

你可能感兴趣的:(常用工具,深度学习,神经网络)