Pytorch 加载自己的训练模型

Pytorch保存我们训练好的模型,然后加载用于测试

第一种方法

(1)保存

torch.save(model.state_dict(), PATH)

# example
torch.save(resnet50.state_dict(),'ckp/model.pth')

(2)恢复

model = ModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))

#example
resnet=resnet50(pretrained=True)
resnet.load_state_dict(torch.load('ckp/model.pth'))

第二种方法

(1)保存

torch.save (model, PATH)

(2)恢复

model = torch.load(PATH)

你可能感兴趣的:(机器学习)