I ran into the error: one of the variables needed for gradient computation has been modified by an inplace operation. It means a tensor was modified in place: a variable that autograd saved for the backward pass was overwritten afterwards, so its version no longer matches the expected one (version 0) and the computed gradients would be wrong.
Use this line of code to locate where the error comes from:
torch.autograd.set_detect_anomaly(True)
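For illustration, here is a tiny repro with made-up tensors (not from the model below). With anomaly detection enabled, the RuntimeError raised in backward() gains a second traceback that points at the exact forward line that did the in-place write:

import torch

torch.autograd.set_detect_anomaly(True)  # enable once, before the forward pass

a = torch.randn(4, requires_grad=True)
x = torch.sigmoid(a)  # sigmoid saves its output for the backward pass
x[0] = 0.0            # indexed assignment is an in-place write on x
x.sum().backward()    # fails, and anomaly mode points at the x[0] = 0.0 line

Note that anomaly detection slows training down, so it is best removed once the bug is found.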
Once the offending line is found, the code can be fixed. Here is my original forward code. The line x[i] = self.dilate_attention[i](qkv[i][0], qkv[i][1], qkv[i][2]) overwrites the original x in place, which is exactly what must be avoided:
for i in range(self.num_dilation):
    x[i] = self.dilate_attention[i](qkv[i][0], qkv[i][1], qkv[i][2])  # in-place write into x
x = x.permute(1, 2, 3, 0, 4).reshape(B, H, W, C)  # 3,1,24,24,171 --> 1,24,24,513
x = self.proj(x)
x = self.proj_drop(x)
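The reason, as I understand it: every tensor carries an internal version counter, and every in-place write increments it; autograd records the version when it saves a tensor for backward, and refuses to run backward if the version has changed since. The snippet below reads the internal attribute _version purely to illustrate this (it is not part of the public API):

import torch

a = torch.randn(2, 3, requires_grad=True)
x = torch.sigmoid(a)
print(x._version)   # 0 -- the version autograd saved for sigmoid's backward
x[0] = x[0] * 2     # in-place write, just like x[i] = ... in the loop above
print(x._version)   # 1 -- no longer matches the saved version 0
x.sum().backward()  # raises the inplace-operation RuntimeError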
The fix is to collect the results in a new variable instead of writing back into x. Create a list, append each branch's output, then combine them with torch.cat (cat concatenates along an existing dimension, so one dimension of the original x is lost and has to be restored with unsqueeze). After this change the error is gone:
x_i = []
for i in range(self.num_dilation):
    x_i.append(self.dilate_attention[i](qkv[i][0], qkv[i][1], qkv[i][2]))
x = torch.cat(x_i, dim=0)  # 3,24,24,171
x = x.unsqueeze(1)  # 3,1,24,24,171
x = x.permute(1, 2, 3, 0, 4).reshape(B, H, W, C)  # 3,1,24,24,171 --> 1,24,24,513
x = self.proj(x)
x = self.proj_drop(x)
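For completeness, here is a self-contained sketch of the fixed pattern, with dummy shapes taken from the comments above; nn.Identity stands in for self.dilate_attention, whose real implementation is not shown here. As a side note, torch.stack(x_i, dim=0) builds the 3,1,24,24,171 tensor in one step, so it could replace the cat + unsqueeze pair:

import torch
import torch.nn as nn

num_dilation, B, H, W = 3, 1, 24, 24
c = 171  # channels per branch; C = num_dilation * c = 513

# stand-in for self.dilate_attention: any module mapping (B, H, W, c) -> (B, H, W, c)
attn = nn.ModuleList(nn.Identity() for _ in range(num_dilation))

qkv = torch.randn(num_dilation, 3, B, H, W, c, requires_grad=True)

x_i = [attn[i](qkv[i][0] + qkv[i][1] + qkv[i][2]) for i in range(num_dilation)]
x = torch.stack(x_i, dim=0)  # 3,1,24,24,171 -- same result as cat(dim=0) + unsqueeze(1)
x = x.permute(1, 2, 3, 0, 4).reshape(B, H, W, num_dilation * c)  # 1,24,24,513

x.sum().backward()     # backward runs: nothing was modified in place
print(qkv.grad.shape)  # torch.Size([3, 3, 1, 24, 24, 171])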