【Pytorch1.1.0多GPU问题】TypeError: 'DataParallel' object is not iterable

使用多GPU运行Pytorch代码时,出现TypeError: 'DataParallel' object is not iterable

妈呀排查的我都要吐血了终于发现问题出在哪儿,还是我太白目。。

写的是CNN代码,写损失函数模型时,首先要把损失函数放到nn.ModuleList中,然后把它作为DataParallel的module参数传进去。

self.loss_module = nn.ModuleList() 
self.loss_module.append(nn.L1Loss())
self.loss_module = nn.DataParallel(
                self.loss_module, range(args.n_GPUs)
            )

之后在多GPU的情况下如果要以iterable的方式获取self.loss_module中的损失函数,就得拿其中的module属性!之前我就直接return self.loss_module了。。

for l in self.get_loss_module():    
    ...

def get_loss_module(self):
        if self.n_GPUs == 1:
            return self.loss_module
        else:
            return self.loss_module.module

 

 

你可能感兴趣的:(问题)