解决报错torch._C._gather(tensors, dim, destination) RuntimeError: Input tensor at index 1 has invalid

在pytorch的分布式训练中,每个卡都会有一个模型(replicate步骤),以及分配的输入(scatter步骤),最后再把每个模型的输出合并(gather步骤),如果每个模型输出的维度不一致的话,是无法gather的。

因此,查看模型return的值,确实是在根据场景实时变化的。其会根据各个样本中具体场景而发生变化,而不同的卡上输出tensor维度不一样,所以无法gather。

报错虽然出现在底层,但是问题本身还是在于模型。在改掉变化的部分之后能够正常运行。

你可能感兴趣的:(PyG)