加载不同pytorch版本之间,训练得到的pkl文件

        我们将服务器上训练好的pkl文件下载到本地电脑上,经常会由于pytorch版本不统一问题,例如出现这种问题“_pickle.UnpicklingError: A load persistent id instruction was encountered...”

而无法直接加载,因此需要将pkl文件装成pth文件,这样不同版本的pytorch就可以互相加载

1. 在服务器端将pkl文件转成pth文件

import pickle as pkl
import torch

info_dict = torch.load('VGG_pre/VGG_4.pkl') #服务器端上的pkl保存位置
with open('pkl_model_vgg16.pth', 'wb') as f: #转成pth文件,并以pkl_model_vgg16.pth命名
    pkl.dump(info_dict, f)

2. 在本地端加载pth文件

import pickle as pkl

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net =  model.VGG16().eval()
net.to(device)  # 放到GPU上
with open(args.model_dir_test, 'rb') as f:
    info_dict = pkl.load(f)
net.load_state_dict(info_dict,strict=False)

 

你可能感兴趣的:(机器学习)