torch.nn.DataParallel保存,单GPU加载

torch.nn.DataParallel是torch多GPU训练的设置
torch.nn.DataParallel保存,单GPU加载

1.torch.nn.DataParallel 保存,多GPU加载
保存
torch.nn.DataParallel(model).cuda() 含有module参数
torch.save(model.state_dict(), ‘demo.pth’)

加载
torch.nn.DataParallel(model).cuda()
checkpoint = torch.load(‘demo.pth’)
model.load_state_dict[checkpoint]

2.torch.nn.DataParallel保存,单GPU加载
加载
checkpoint = torch.load(‘demo.pth’)
model.load_state_dict({k.replace(‘module.’,‘’):v for k,v in torch.load(‘demo.pth’).items()})

3.单GPU保存,单GPU加载
保存
torch.save(model.state_dict(), ‘demo.pth’)
加载
checkpoint = torch.load(‘demo.pth’)
model.load_state_dict[checkpoint]

ps:model.load_state_dict({k.replace(‘module.’,‘’):v for k,v in torch.load(‘demo.pth’).items()})这种方式在单GPU保存,单GPU加载中使用,实际运行结果也没有问题,key中本身没有mudule,replace也无所谓,但是通常还是使用model.load_state_dict[checkpoint]方式

你可能感兴趣的:(torch,pytorch,深度学习)