pytorch模型载入之gpu和cpu互转

gpu转cpu

model = ModelArch(para)   # 网络结构
model = torch.nn.DataParallel(model, device_ids=[0]).cuda()  # 将model转为gpu模式
checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage) # 载入weights
model.load_state_dict(checkpoint)  # 用weights初始化网络
cpu_model = model.module           # 转为cpu模式
# cpu模型存储, 注意这里的state_dict后的()必须加上,否则报'function' object has no attribute 'copy'错误
torch.save(cpu_model.state_dict(), 'cpu_mode.pth') 

cpu转gpu

model = ModelArch(para)   # 网络结构
checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage.cuda(0)) # 载入weights
model.load_state_dict(checkpoint)  # 用weights初始化网络
torch.save(model.state_dict(), 'gpu_mode.pth') 

 

你可能感兴趣的:(pytorch,模型载入,state_dict)