pytorch-detach

import torch

torch.random.manual_seed(1)

# w1 and w2 play the role of the parameters of f1 and f2
w1 = torch.tensor([2.0], requires_grad=True)
print('w1', w1)

w2 = torch.tensor([2.0], requires_grad=True)
x2 = torch.rand(1, requires_grad=True)

y2 = x2 ** w1      # the f1 computation
z2 = w2 * y2 + 1   # the f2 computation
# z2 = z2.detach()  # detaching here, before backward, would cut off the gradient flow
z2.backward()
# z2 = z2.detach()  # detaching here, after backward, has no effect on the gradients
print(x2.grad)
print(y2.grad)
print(w1.grad)
print(w2.grad)
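For this graph, z2 = w2 * x2 ** w1 + 1, so backward() produces dz2/dw2 = x2 ** w1 (i.e. the value of y2), dz2/dw1 = w2 * x2 ** w1 * ln(x2), and dz2/dx2 = w2 * w1 * x2 ** (w1 - 1). Note that y2.grad prints None: y2 is an intermediate (non-leaf) tensor, and PyTorch does not retain its gradient unless y2.retain_grad() is called before backward().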

Placing detach() after backward() has no effect: by that point the gradients have already been computed and written to the leaf tensors.
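For contrast, here is a minimal sketch of the opposite case, detaching before backward(), using the same graph as above (the name y2_cut is introduced here only for illustration). The detach cuts the graph at y2, so no gradient flows back through f1 to w1 or x2, while w2 still receives its gradient.

import torch

torch.random.manual_seed(1)

w1 = torch.tensor([2.0], requires_grad=True)
w2 = torch.tensor([2.0], requires_grad=True)
x2 = torch.rand(1, requires_grad=True)

y2 = x2 ** w1            # f1
y2_cut = y2.detach()     # cut the graph here, before backward
z2 = w2 * y2_cut + 1     # f2 now sees a constant in place of y2
z2.backward()

print(w2.grad)  # populated: equals the value of y2
print(w1.grad)  # None: the detach blocked the path back through f1
print(x2.grad)  # None, for the same reason

detach() returns a new tensor that shares the same data but is removed from the autograd graph, which is why everything upstream of the detach point is excluded from this backward pass.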
