pytorch-lightning的一些记录

  • 收集每个GPU上的输出
    在分布式训练时,每个GPU都会有一部分数据,当我们需要使用全部的数据进行计算时,我们需要收集所有GPU的tensor。
    比如两个GPU,第一个GPU有16组数据,第二个GPU有16组数据, 在进行对比学习计算时,我们需要收集所有的输出来增加负样本的数量。
    我们可以使用tensors_from_all = self.all_gather(my_tensor)
    比如:
    def training_step(self, batch, batch_idx):
        outputs = self(batch)
        ...

        all_outputs = self.all_gather(outputs, sync_grads=True)

        loss = contrastive_loss_fn(all_outputs, ...)
        return loss

你可能感兴趣的:(pytorch-lightning的一些记录)