torch叶子节点才能保存grad,叶子节点如何修改才不变为中间节点,保留grad呢?使用data

#梯度存储在自变量中,grad属性中
import torch

x = torch.tensor([3.0,5],requires_grad=True)#x设为可以求梯度,由他生成的变量均可求导

x1=torch.tensor([3.7,4])#默认不可求梯度

y = x ** 2+x1  #x是叶子节点,y是非叶子节点,backward()后y.grad_fn, y.grad不存在


z=y[0]+y[1]

#判断x,y,z是否是可以求导的
print("x1:",x1.requires_grad)
print("x:",x.requires_grad)
print("y:",y.requires_grad)

z.backward()  #反向求导
##print("y.grad",y.grad()) #y非叶子节点 grad不存在 报错
print("x.grad",x.grad)

x=x+1  #x变为非叶子节点 grad将不存在-------------替换为:x.data=x.data+1
#要改变x的值可以用 x.data=x.data+1,x还是叶子节点还能保存grad

print(x.grad)

你可能感兴趣的:(torch叶子节点才能保存grad,叶子节点如何修改才不变为中间节点,保留grad呢?使用data)