torch.load received a zip file

torch.load received a zip file

 

cltt

torch.load()加载模型,提示xxx.pt is a zip archive(did you mean to use torch.jit.load()?)

参考链接

https://blog.csdn.net/irober/article/details/115144522?utm_medium=distribute.pc_relevant.none-task-blog-baidujs_title-0&spm=1001.2101.3001.4242

torch版本问题

import torch
state_dict = torch.load("./model_00050.pt")#加载原来的模型
torch.save(state_dict, "./00050.pt", _use_new_zipfile_serialization=False)#不是zip 

上面这个报错,说是用这个; 


#在torch 1.6版本中重新加载一下网络参数
model = MyNetwork().to(device) #实例化模型并加载到cpu货GPU中
model.load_state_dict(torch.load(model_cp))  #加载模型参数,model_cp为之前训练好的模型参数(zip格式)
#重新保存网络参数,此时注意改为非zip格式
torch.save(model.state_dict(), model_cp,_use_new_zipfile_serialization=False)

 

你可能感兴趣的:(pytorch知识宝典)