【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误

【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误

  • 报错详情
  • 错误产生背景
  • 原理
  • 解决方案

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

报错详情

  模型在backward时,发现如下报错:
请添加图片描述
  即RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

  其大概意思是说,当在计算梯度时,某个变量已经被操作修改了,这会导致随后的计算梯度的过程中该变量的值发生变化,从而导致计算梯度出现问题。

错误产生背景

  起因是我要复现一种层级多标签分类的网络结构:
【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误_第1张图片
  当输入序列 x x x经过一次BERT模型之后,得到当前预测的一级标签,然后拼接到输入序列 x x x上,再次输入到BERT模型里以预测二级标签。

  出错版本的模型结构如下:

def forward(self, x, label_A_emb):
        context = x[0]  # 输入的句子
        mask = x[2]  

        d1 = self.bert(context, attention_mask=mask)
        logit1 = self.fc1(d1[1])  # [batch_size, label_A_num] = [128, 34]
        idx = torch.max(logit1.data, 1)[1] # [batch_size] = [128]
        extra = label_A_emb[idx]

        context[:, -3:] = extra
        mask[:, -3:] = 1

        d2 = self.bert(context, attention_mask=mask)
        logit2 = self.fc2(d2[1])  # [batch_size, label_B_num] = [128, 34]

        return logit1, logit2

  在计算梯度时,由于contextmask的值被中间修改过一次,所以会报错。

原理

【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误_第2张图片
  图中 w 1 w_1 w1的梯度计算如上图,损失函数为 E t o t a l E_{total} Etotal,最终 w 1 w_1 w1的梯度里是需要用到原始输入 i 1 i_1 i1的。

  所以在上面贴的模型结构代码中,输入在经过神经网络之后,又作了一次改动,然后再经过神经网络。但是梯度计算会计算两次的梯度,可是发现输入只有改动后的值了,改动前的值已经被覆盖。

计算梯度时的版本号机制是PyTorch中用于跟踪张量操作历史的一种机制。它允许PyTorch在需要计算梯度时有效地管理和跟踪相关的操作,以便进行自动微分。每个张量都有一个版本号,记录了该张量的操作历史。当对一个张量执行就地操作(inplace operation)时,例如修改张量的值或重新排列元素的顺序,版本号会增加。这种就地操作可能导致计算梯度时出现问题,因为梯度计算依赖于操作历史。

解决方案

  把即将改动的变量深拷贝一份,最终优化的代码如下:

def forward(self, x, label_A_emb):
        context = x[0]  # 输入的句子
        mask = x[2]  

        d1 = self.bert(context, attention_mask=mask)
        logit1 = self.fc1(d1[1])  # [batch_size, label_A_num] = [128, 34]
        idx = torch.max(logit1.data, 1)[1] # [batch_size] = [128]
        extra = label_A_emb[idx]

        context_B = copy.deepcopy(context)
        mask_B = copy.deepcopy(mask)

        context_B[:, -3:] = extra
        mask_B[:, -3:] = 1

        d2 = self.bert_A(context_B, attention_mask=mask_B)
        logit2 = self.fc2(d2[1])  # [batch_size, label_B_num] = [128, 34]

        return logit1, logit2

你可能感兴趣的:(机器学习,深度学习,人工智能,梯度更新,反向传播)