kaggle上模型的保存与读取----Pytorch框架

模型保存

torch.save(model.state_dict(),'./model_best.pth')

一般情况下,以上语句保存的模型在kaggle的output/kaggle/working文件夹下
建议将效果较好的模型下载保存,否则网页休眠之后之前训练的结果就都么得了

我这里使用的模型是经过预训练的resnet50

模型加载

model =  torchvision.models.segmentation.fcn_resnet50(pretrained= False,progress= True)
#model.classifier[4] = nn.Conv2d(512,1,kernel_size=(1,1),stride= (1,1))#之前训练对网络进行了修改
model.load_state_dict(torch.load(' ./model_best.pth '),strict=False)
model = model.to(device)#记得将模型转移到gpu

如果报错说找不到文件夹,检查你的路径是否正确

如有错误,还请指正

你可能感兴趣的:(深度学习,1024程序员节,pytorch,深度学习,神经网络,机器学习)