pytorch 计算图

例1

假设我们函数是 y = ( x + w ) ∗ ( w + 1 ) y=(x+w)*(w+1) y=(x+w)(w+1),我们要求 y y y x 和 w x和w xw的导数,应该如何用pytorch来求解。
pytorch 计算图_第1张图片
上面的计算图表示 y = ( x + w ) ∗ ( w + 1 ) y=(x+w)*(w+1) y=(x+w)(w+1),先计算括号内部的加法,再计算乘法。计算顺序是: a = x + w a=x+w a=x+w b = w + 1 b=w+1 b=w+1 y = a ∗ b y=a*b y=ab
用代码来表示:

import torch

w = torch.tensor([1.], requires_grad=False)  
x = torch.tensor([2.], requires_grad=True) 

a = torch.add(w, x)     # a = w + x
b = torch.add(w, 1)     # b = w + 1
y = torch.mul(a, b)     # y = a * b

# a.retain_grad()  用于保持非叶子节点的梯度
y.backward()    #反向传播求导
print(a.grad_fn)
print(x.grad) 
print(w.grad)
print(a.is_leaf, b.is_leaf, y.is_leaf, w.is_leaf, x.is_leaf)

'''

tensor([2.])
None
False True False True True
'''

pytorch创建的计算图分为叶子节点和非叶子节点。每一个节点都是一个tensor, tensor具有属性requires_grad(记录该tensor是否要求梯度),is_leaf(记录是否是叶子节点), grad_fn(记录创建该tensor的方法)。

  1. requires_grad: 如果创建这个tensor的输入中,至少有一个tensor的requires_grad=True,那么新创建的这个tensor的requires_grad=True。在上面这个例子中, a a a是由 w 和 x w和x wx相加得到的, x x xrequires_grad=True,所以 a a arequires_grad=True

  2. is_leaf:叶子节点是你所有手动创建的tensor,在这个例子中,叶子节点是 x , w 还 有 b x,w还有b xwb。注意,叶子节点的requires_grad并不一定是True。在本例中, w w w也是叶子节点,但是其requires_grad=False。再比如你创建的神经网络的参数是叶子节点,其requires_grad=True,比如你创建的模型的输入,虽然requires_grad=False,但是也是叶子节点。

  3. grad_fn:记录创建这个tensor的方法,比如本例中, a a agrad_fn就是AddBackward0,表示由加法得到。

还有一点需要注意,只有叶子节点的梯度在backward()之后是不被销毁的,非叶子节点的梯度在backward()之后是被销毁的,可以在y.backward() 之后打印 a a a的梯度试试。如果想保持飞叶子节点的梯度,在backward()之前,使用a.retain_grad()

例2

上面这个例子我们看到, b b b也是叶子节点,这似乎有点难以理解,如果我们再创建一个tensor, c = x + 1 c=x+1 c=x+1,那么 c c c是否是叶子节点?实践可知,c.is_leaf=False

例3

我们再举一个例子。观察下面的代码:

mport torch

w = torch.tensor([1.], requires_grad=False)  
x = torch.tensor([2.], requires_grad=True)  

print(x.is_leaf)
x = x + 1
# x.data.copy_(x.data+1)
print(x.is_leaf)
print(x.requires_grad)

# w = w+w

a = torch.add(w, x)     # a = w + x
b = torch.add(w, 1)     # b = w + 1
y = torch.mul(a, b)     # y = a * b

y.backward()    #对y进行反向传播

print(x.grad)   

'''
True
False
True
None
'''

上面代码表达的意思是,先对 x x x进行自加1再算 y y y。但是打印结果发现,在 x x x自加1之后, x x x就不是叶子节点了。

通过例2和例3是否可得到结论,对于requires_grad=True的叶子节点来说,对其做任何改动,得到的新的tensor都不是叶子节点了。这个结论为时过早,将x=x+1替换为x.data.copy_(x.data+1),可发现,x还是叶子节点。所以如果想对叶子节点的值进行改变,应该用copy_函数,而不是直接用等号改变。

tensor.data.copy_():只会改变tensor的data值,而不会改变is_leaf, requires_grad等其他属性值。

参考:pytorch——计算图与动态图机制

你可能感兴趣的:(pytorch)