torch.load报错:No module named ‘models‘

使用pytorch训练模型时想要预先加载预训练模型,忽然出现这种错误。
原因大概是该预训练模型保存方法是完全保存:

torch.save(model, path)

该方法将模型内容全部保存,甚至包括存放路径
这导致将保存的模型换位置的后,load加载的时候可能导致路径出现问题

解决方法:

model = Model()
scripted_module = torch.jit.script(model)
torch.jit.save(scripted_module, 'pretrained_model.pt')
torch.jit.load('pretrained_model.pt')

参考自

避免该问题的方法:
在保存模型的时候只保存状态字典,不要全部保存了!
即:

torch.save(model.state_dict(), PATH)

参考自

你可能感兴趣的:(深度学习,pytorch,预训练模型)