Pytorch多GPU训练

Pytorch多GPU训练

1. torch.nn.DataParallel

torch.nn.DataParallel()这个主要适用于单机多卡。个人一般比较喜欢在程序开始前,import包之后使用os.environ['CUDA_VISIBLE_DEVICES']来优先设定好GPU。例如要使用物理上第0,3号GPU只要在程序中设定如下:

os.environ['CUDA_VISIBLE_DEVICES'] = '0,3'

**注意:**如上限定物理GPU后,程序实际上的编号默认为device_ids[0],device_ids[1]。就是说程序所使用的显卡编号实际上是经过了一次映射之后才会映射到真正的显卡编号上面的。所以device_ids这个参数后续就不用再另行设置了。

batch_size设定

batch——size的大小应该大于所使用的GPU的数量。还应当是GPU个数的整数倍,这样划分出来的每一块都会有相同的样本数量。现batch_size = 原batch_size * num_GPUs

加载模型

model = nn.DataParallel(model)
model = model.cuda()

当然直接指定device_ids也可以:

net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
model = model.cuda()

加载数据

inputs = inputs.cuda()
labels = labels.cuda()

注意点

注意:nn.DataParallel(model)这句返回的已经不是原始的m了,而是一个DataParallel,原始的m保存在DataParallel的module变量里面。解决方法:

保存模型

  • 保存的时候就取出原始model:
torch.save(model.module.state_dict(), path)
  • 或者载入的时候用一个DataParallel载入,再取出原始模型:
model = nn.DataParallel(Resnet18())
model.load_state_dict(torch.load(path))
model = model.module

优化器

在训练过程中,你的优化器同样可以使用nn.DataParallel,如下两行代码:

optimizer = torch.optim.SGD(net.parameters(), lr=lr)
optimizer = nn.DataParallel(optimizer, device_ids=device_ids)
# 优化器使用:
optimizer.step() --> optimizer.module.step()

Warning

UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars;
will instead unsqueeze and return a vector.

关于此的讨论:

https://github.com/pytorch/pytorch/issues/9811


torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)

参数说明:

  • module – 要被并行的module
  • device_ids – CUDA设备,默认为所有设备。所以为了省事一般在前面加os.environ['CUDA_VISIBLE_DEVICES'] = '0,3'
  • output_device – 输出设备(默认为device_ids[0]) 。所以所使用的0号卡,显存占用总是比较高。 负载不均衡很严重的话,建议使用DistributedDataParallel

此容器通过将mini-batch划分到不同的设备上来实现给定module的并行。在forward过程中,module会在每个设备上都复制一遍,每个副本都会处理部分输入。在backward过程中,副本上的梯度会累加到原始module上。

dataparallel只是数据input被分到不同卡上,模型还是只在device0上的.首先各个卡只计算到loss,然后0号卡做loss平均,最后分发到各个卡上求梯度并进行参数更新。

Reference:

OPTIONAL: DATA PARALLELISM

PyTorch官方中文

pytorch 多 gpu 并行训练
https://blog.csdn.net/qq_34243930/article/details/106695877
https://zhuanlan.zhihu.com/p/86441879
https://zhuanlan.zhihu.com/p/102697821

你可能感兴趣的:(计算机视觉,Pytorch,pytorch,GPU)