<input>:1: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed.

前言:

在调试计算模型梯度的时候突然蹦出::1: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.

(持续更新,)

1更:2022.08.30

一、问题分析

博主想得到loss计算环节某个变量的梯度,经过检查发现该变量的requires_grad()为True,但是在输出他的梯度的时候却出现上面的警告。

这个报错的大致翻译如下:

:1:UserWarning:正在访问非叶张量的张量的.grad属性。在autograd.backward()期间不会填充其.grad属性。如果确实需要非叶张量的梯度,请在非叶张量上使用.retain_grad()。如果您错误地访问了非叶张量,请确保您访问了叶张量。

一句话总结是:

并不是 requires_grad()为True就可以输出对应的梯度,还要看is_leaf属性,当is_leaf=false时,也即该变量是非叶张量,则会爆出上面的错误。

<input>:1: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed._第1张图片

二、什么是叶张量?

首先根据张量的属性判断,当张量的is_leaf为true时,该变量为叶张量

那么is_leaf是如何产生的:

由用户初始创建的变量,而不是程序中间产生的结果变量,那么该变量为叶变量。

如所示:

x1 = Variable(torch.ones(2, 2)*2,requires_grad=False)
print('x1是否是叶张量:',x1.is_leaf)
x2 = Variable(torch.ones(2, 2)*3, requires_grad=True)
print('x2是否是叶张量:',x2.is_leaf)

x3 = Variable(torch.ones(2, 2)*4, requires_grad=True)
y = x1*x3 + x2
t_1 = y.sum()
print('y是否是叶张量:',y.is_leaf)
print('t_1是否是叶张量:',t_1.is_leaf)

运行程序如下图所示:

y和t_1不是用户创建的,所以他们是非叶张量。

<input>:1: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed._第2张图片

 三、为什么要区分叶张量和非叶张量?

从以下几个方面来讲:

1.从链式求导法则来讲:

          对于一个可导函数,无论其多么复杂,一定可由若干可导函数的四则运算组成,那么在这个环节中任何一个运算都是可导的,最终得到函数的计算是需要这些环节的导数的传递。不知道讲清楚没有,,,例如:对上面的x1求导,那么自上而下要经过t_1,y,t_1和y的requires_grad()为True,否则梯度无法求解。

2.从求导的目的来讲

在深度学习或者机器学子中,求导或求梯度的根本原因是,对梯度进行反向传播,进而来更新相应的参数来达到学习的目的,在这个例子中可以看出y,t_1是前向传播计算得到的中间值,它们虽然需要求导但是不需要更新参数。

3.从实际需要出发

那些需要更新参数的变量,设置为叶张量,在计算梯度的时候,会保留它们的梯度信息,供反向传播更新参数的时候使用;那些不需要更新参数的中间变量,任然会计算与之相关联的梯度,但是程序不会保留该变量对应的梯度信息;这样有利于节约计算机的资源。

你可能感兴趣的:(pytorch,pytorch,深度学习,python)