Pytorch 保存 加载模型torch.save() torch.load()

Pytorch 保存和加载模型

Pytorch 保存和加载模型后缀:.pt 和.pth

保存整个模型:

torch.save(model,'save.pt')

只保存训练好的权重:

torch.save(model.state_dict(), 'save.pt')

加载模型:

pretrained_dict = torch.load("save.pt")

只加载模型参数:

model.load_state_dict(torch.load("save.pt"))  #model.load_state_dict()函数把加载的权重复制到模型的权重中去

加载某一层的训练的到的参数

conv1_weight_state = torch.load('save.pt')['conv1.weight']

你可能感兴趣的:(Pytorch)