Pytorch nn.DataParallel()的简单用法

简单来说就是使用单机多卡进行训练。
一般来说我们看到的代码是这样的:

net = XXXNet()
net = nn.DataParallel(net)

这样就可以让模型在全部GPU上训练。

方法定义:

class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)
  • module:需要进行并行的模型
  • device_ids:并行所用的GPU。可以是int列表也可以是device对象,默认不写就是使用全部GPU
  • output_device:输出所用的GPU。可以是GPU id或device对象,默认不写就是第一张(device_ids[0])

参考

https://www.aiuai.cn/aifarm1340.html
https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-nn/

你可能感兴趣的:(Pytorch)