pytorch torch: slice赋值以及clone不会截断梯度

下面代码tensor b的梯度能不能正常传递给a

a = torch.tensor([1, 2, 3]).float()
a.requires_grad_()
b = a.new_full((4,), 0)
b[[1, 2, 3]] = a 

答案是可以。

  • 下面的前两个输出表明赋值是深拷贝,不是浅拷贝。a,b是两个不同的内存
  • 最后两个输出表明,b不被重新赋值的部分能够将梯度反传回a。
import torch

a = torch.tensor([1, 2, 3]).float()
a.requires_grad_()

b = a.new_full((4,), 0)
b[[1, 2, 3]] = a  # 这里写 b[[1, 2, 3]] = a.clone() 结果也是一样的
b[1] = 100
print(a, b)
# tensor([1., 2., 3.], requires_grad=True)
# tensor([  0., 100.,   2.,   3.], grad_fn=)

((b - 1.5) ** 2).sum().backward()
print(a.grad)
# tensor([0., 1., 3.])

a = torch.tensor([1, 2, 3]).float()
a.requires_grad_()
((a - 1.5) ** 2).sum().backward()
print(a.grad)
# tensor([-1.,  1.,  3.])

pytorch里面,clone, 赋值都是可导的,梯度是不会被截断的,只有detach才会截断。

你可能感兴趣的:(pytorch,python)