pytorch 多GPU训练遇到的坑

在打算由单GPU转成多GPU时,使用:

self.model = self.model.cuda(device=device_ids[0])              
 self.model = torch.nn.DataParallel(self.model,device_ids=device_ids)

设置后,出现报错,报错显示读取不到模型的参数,原因是因为经过DataParallel包装过的模型如下:
pytorch 多GPU训练遇到的坑_第1张图片
和使用单GPU时不同的是多了一个.module,所以再进行任何需要调用model里面参数的操作时,都需要在model后面加上一个.module,即model.module,这样才能提取出model里面的参数以及函数等。如图所示:
pytorch 多GPU训练遇到的坑_第2张图片

你可能感兴趣的:(pytorch)