# Indices of the GPUs to use; more than one entry enables DataParallel below.
gpu_ids = [0, 1, 2, 3]
# A torch.device names exactly one GPU. DataParallel expects the wrapped
# module's parameters to live on device_ids[0], so the primary device is taken
# from gpu_ids instead of being hard-coded to "cuda:0". Falls back to the CPU
# when CUDA is unavailable.
device = t.device(
    f"cuda:{gpu_ids[0] if gpu_ids else 0}" if t.cuda.is_available() else "cpu"
)
net = LeNet()
if len(gpu_ids) > 1:
    # Wrapping adds an outer `module` level around the original network.
    net = nn.DataParallel(net, device_ids=gpu_ids)
net = net.to(device)
由于多GPU训练使用了 nn.DataParallel(net, device_ids=gpu_ids) 对网络进行封装，因此在原始网络结构外多包了一层 module，state_dict 中的参数名也会相应带上 "module." 前缀。网络结构如下：
DataParallel(
(module): LeNet(
(conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(fc1): Linear(in_features=400, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
)
而不使用多GPU训练的网络结构如下:
LeNet(
(conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(fc1): Linear(in_features=400, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
# Save the bare network's weights so the stored keys carry no "module." prefix
# and can be loaded into a plain (non-DataParallel) model. The check looks at
# how `net` was actually wrapped rather than at len(gpu_ids), so the two cannot
# drift apart.
if isinstance(net, nn.DataParallel):
    t.save(net.module.state_dict(), "model.pth")
else:
    t.save(net.state_dict(), "model.pth")
或者将模型参数与训练状态一起写入字典后保存：
def save_model(model, cudan=4):
    """Write a training checkpoint to ``<dir_checkpoint>/best_model.pth``.

    The checkpoint is a dict saved with ``torch.save``. Besides the model
    weights it stores values read from the enclosing scope: ``nb`` (saved as
    ``nb + 1`` under ``'epoch'``), ``newmIoU``, ``dev_loss``, ``lr`` and the
    state dict of ``optimizer``.

    Args:
        model: Network whose weights are saved under ``'model_state_dict'``.
        cudan: When greater than 1, ``model`` is treated as a multi-GPU
            wrapper and the weights of ``model.module`` are saved, so the
            keys carry no ``"module."`` prefix.
    """
    savepath = str(dir_checkpoint) + '/best_model.pth'
    # Everything except the model weights; those are added below because the
    # source of the state dict depends on whether the model is wrapped.
    state = {
        'epoch': nb + 1,
        'mIoU': newmIoU,
        'dev_loss': dev_loss,
        # NOTE(review): the key is literally "lr:" (trailing colon), which
        # looks like a typo. Kept unchanged so existing checkpoint readers
        # keep working -- confirm before renaming.
        "lr:": lr,
        'optimizer_state_dict': optimizer.state_dict(),
    }
    # Only unwrap when the model really exposes `.module`; with the default
    # cudan=4 an unwrapped model would otherwise raise AttributeError.
    if cudan > 1 and hasattr(model, 'module'):
        state['model_state_dict'] = model.module.state_dict()  # multi-GPU wrapper
    else:
        state['model_state_dict'] = model.state_dict()  # plain model
    torch.save(state, savepath)
# Checkpoints written from a DataParallel-wrapped model prefix every key with
# "module.". Strip only that leading prefix: str.replace would also rewrite
# keys that merely contain "module." (e.g. "submodule.weight" -> "subweight").
model.load_state_dict({
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in checkpoint["model_state_dict"].items()
})
# Wrap for multi-GPU use only after the weights are loaded into the bare model.
model = nn.DataParallel(model).cuda()
或者直接从保存的 state_dict 文件中加载：
# NOTE(review): `net` has to be the model class (or a factory) here; earlier
# in this file `net` is bound to a model instance -- confirm which is intended.
model2 = net()
# map_location='cpu' lets the file load even when the GPU it was saved from is
# not present; the weights are moved to the GPU by .cuda() below.
# Strip only the leading "module." prefix: str.replace would also rewrite keys
# that merely contain "module." (e.g. "submodule.weight" -> "subweight").
model2.load_state_dict({
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in torch.load('demo.pth', map_location='cpu').items()
})
# Wrap for multi-GPU use only after the weights are loaded into the bare model.
model2 = nn.DataParallel(model2).cuda()
参考链接:
[1] https://blog.csdn.net/anshiquanshu/article/details/122157157
[2] https://www.jb51.net/article/189297.htm