StopIteration: Caught StopIteration in replica 0 on device 0. 问题排查与解决

首先是错误内容截图:(抱歉因为打码有点糊)

StopIteration: Caught StopIteration in replica 0 on device 0. 问题排查与解决_第1张图片我在训练修改后的TransformerXL时,发现了如上的错误,此前代码已经成功地在单GPU下运行过,切换到多卡运行出现该问题。尝试进行解决。

使用的环境是: Pytorch1.11 transformers:4.18

在网上进行查阅后大部分人都说可能是pytorch版本的问题,当前所使用的pytorch版本过高,需要降级到1.4.0版本。

降级听起来比较简单,但是我不想降级到太低的版本,只能走第二条路,修改代码。

首先定位到出错的非源码的最后一行,

param = next(self.parameters())

经过上网查找,发现可能是在训练过程中部分数据的精度不同导致的问题,可能同时存在16位精度和32位精度的数据,尝试在这里进行修改,将其直接指定为torch.float32 进行训练。

原始代码为:

    def init_mems(self):
        if self.mem_len > 0:
            mems = []
            param = next(self.parameters())
            for i in range(self.n_layer+1):
                empty = torch.empty(0, dtype=param.dtype, device=param.device)
                mems.append(empty)
            return mems
        else:
            return None

更改后的代码是: 

    def init_mems(self):
        if self.mem_len > 0:
            mems = []
            for i in range(self.n_layer+1):
                empty = torch.empty(0, dtype=torch.float32).cuda()
                mems.append(empty)
            return mems
        else:
            return None

成功! 问题解决! 

你可能感兴趣的:(Pytorch,pytorch,人工智能,深度学习)