pytorch的计算图 loss.backward(retain_graph=True) # 添加retain_graph=True标识,让计算图不被立即释放

前言:

接触pytorch这几个月来,一开始就对计算图的奥妙模糊不清,不知道其内部如何传播。这几天有点时间,就去翻阅了Github,pytorch Forum,还有很多个人博客(后面会给出链接),再加上自己的原本一些见解,现在对它的计算图有了更深层次的理解。

pytorch是非常好用和容易上手的深度学习框架,因为它所构建的是动态图,极大的方便了coding and debug。可是对于初学者而言,计算图是一个需要深刻理解的概念,在后期的搭建的神经网络都是基于计算图而设计的。

一、构建计算图

pytorch是动态图机制,所以在训练模型时候,每迭代一次都会构建一个新的计算图。而计算图其实就是代表程序中变量之间的关系。举个列子: 在这个运算过程就会建立一个如下的计算图:

pytorch的计算图 loss.backward(retain_graph=True) # 添加retain_graph=True标识,让计算图不被立即释放_第1张图片 pytorch的计算图 loss.backward(retain_graph=True) # 添加retain_graph=True标识,让计算图不被立即释放_第2张图片

在这个计算图中,节点就是参与运算的变量,在pytorch中是用Variable()变量来包装的,而图中的边就是变量之间的运算关系,比如:torch.mul(),torch.mm(),torch.div() 等等。

注意图中的 leaf_node,叶子结点就是由用户自己创建的Variable变量,在这个图中仅有a,b,c 是 leaf_node。为什么要关注leaf_node?因为在网络backward时候,需要用链式求导法则求出网络最后输出的梯度,然后再对网络进行优化,如下就是网络的求导过程。

pytorch的计算图 loss.backward(retain_graph=True) # 添加retain_graph=True标识,让计算图不被立即释放_第3张图片 pytorch的计算图 loss.backward(retain_graph=True) # 添加retain_graph=True标识,让计算图不被立即释放_第4张图片

二、图的细节

pytoch构建的计算图是动态图,为了节约内存,所以每次一轮迭代完之后计算图就被在内存释放,所以当你想要多次backward时候就会报如下错:

net = nn.Linear(3, 4)  # 一层的网络,也可以算是一个计算图就构建好了
input = Variable(torch.randn(2, 3), requires_grad=True)  # 定义一个图的输入变量
output = net(input)  # 最后的输出
loss = torch.sum(output)  # 这边加了一个sum() ,因为被backward只能是标量
loss.backward() # 到这计算图已经结束,计算图被释放了

上面这个程序是能够正常运行的,但是下面就会报错

net = nn.Linear(3, 4)
input = Variable(torch.randn(2, 3), requires_grad=True)
output = net(input)
loss = torch.sum(output)

loss.backward()
loss.backward()

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed.

之所以会报这个错,因为计算图在内存中已经被释放。但是,如果你需要多次backward只需要在第一次反向传播时候添加一个标识,如下:

net = nn.Linear(3, 4)
input = Variable(torch.randn(2, 3), requires_grad=True)
output = net(input)
loss = torch.sum(output)
loss.backward(retain_graph=True) # 添加retain_graph=True标识,让计算图不被立即释放
loss.backward()

这样在第一次backward之后,计算图并不会被立即释放。

读到这里,可能你对计算图中的backward还是一知半解。例如上面提过backward只能是标量。那么在实际运用中,如果我们只需要求图中某一节点的梯度,而不是整个图的,又该如何做呢?下面举个例子,列子下面会给出解释。

x = Variable(torch.FloatTensor([[1, 2]]), requires_grad=True)  # 定义一个输入变量
y = Variable(torch.FloatTensor([[3, 4],
[5, 6]]))
loss = torch.mm(x, y) # 变量之间的运算
loss.backward(torch.FloatTensor([[1, 0]]), retain_graph=True) # 求梯度,保留图
print(x.grad.data) # 求出 x_1 的梯度
x.grad.data.zero_() # 最后的梯度会累加到叶节点,所以叶节点清零
loss.backward(torch.FloatTensor([[0, 1]])) # 求出 x_2的梯度
print(x.grad.data) # 求出 x_2的梯度

结果如下:

3  5
[torch.FloatTensor of size 1x2]

4 6
[torch.FloatTensor of size 1x2]

可能看到上面例子有点懵,用数学表达式形式解释一下,上面程序等价于下面的数学表达式:

这样我们就很容易利用backward得到一个雅克比行列式:

[公式]

到这里应该对pytorch的计算图和backward有一定了解了吧。

如有错误,欢迎指正。


References:

Calculus on Computational Graphs: Backpropagation

Computational Graphs in PyTorch

PyTorch中的backward

你可能感兴趣的:(pytorch)