torch报错:StopIteration: Caught StopIteration in replica 0 on device 0.

pytorch DataParallel报错解决

  • 错误展示
  • 问题原因
  • 解决方法

错误展示

错误名称:

StopIteration: Caught StopIteration in replica 0 on device 0.

包版本:

pytorch-pretrained-bert 0.6.2
torch                   1.6.0

错误如下:

torch报错:StopIteration: Caught StopIteration in replica 0 on device 0._第1张图片
torch报错:StopIteration: Caught StopIteration in replica 0 on device 0._第2张图片

问题原因

使用单gpu的时候是正常的,但是使用多gpu的时候会报错。问题是多gpu进行模型训练的时候产生的,具体为,不能够用多gpu加载预训练的bert。应该是torch版本的问题。根据2可以知道,torch1.5版本有这个问题,我是torch1.6也有这个问题,据3替换为torch1.4可以解决该问题。

解决方法

比较简单粗暴的解决方法如下:
注意有如下问题:

  File "/miniconda/lib/python3.7/site-packages/pytorch_pretrained_bert/modeling.py", line 727, in forward
    extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility

进入site-packages目录
/miniconda/lib/python3.7/site-packages/pytorch_pretrained_bert/modeling.py 这个路径下的modeling.py脚本把727行的
next(self.parameters()).dtype换成torch.float32

你可能感兴趣的:(深度学习,bug,python,pytorch)