Pytorch多GPU训练DataParallel的使用

Pytorch官网有个简单的示例

https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html

其实用起来还是比较简单的,大致如下:

from torch.nn import DataParallel


model = model.cuda()
model = DataParallel(model, list(range(torch.cuda.device_count()))).cuda()

# AttributeError: 'DataParallel' object has no attribute XXX
model.module.XXX

 

你可能感兴趣的:(Pytorch,DataParallel,多GPU训练)