下面代码tensor b的梯度能不能正常传递给a
a = torch.tensor([1, 2, 3]).float()
a.requires_grad_()
b = a.new_full((4,), 0)
b[[1, 2, 3]] = a
答案是可以。
import torch
a = torch.tensor([1, 2, 3]).float()
a.requires_grad_()
b = a.new_full((4,), 0)
b[[1, 2, 3]] = a # 这里写 b[[1, 2, 3]] = a.clone() 结果也是一样的
b[1] = 100
print(a, b)
# tensor([1., 2., 3.], requires_grad=True)
# tensor([ 0., 100., 2., 3.], grad_fn=)
((b - 1.5) ** 2).sum().backward()
print(a.grad)
# tensor([0., 1., 3.])
a = torch.tensor([1, 2, 3]).float()
a.requires_grad_()
((a - 1.5) ** 2).sum().backward()
print(a.grad)
# tensor([-1., 1., 3.])
pytorch里面,clone, 赋值都是可导的,梯度是不会被截断的,只有detach才会截断。