pytorch的计算图loss.backward(retain_graph=True)实现复杂loss的训练

pytorch的计算图loss.backward(retain_graph=True)实现复杂loss的训练_第1张图片
如图是pytorch采用计算图来求解线性方程: y ( h , x ) = W h ∗ h + W x ∗ x y(h,x)=W_{h}*h+W_{x} *x y(h,x)=Whh+Wxx其中‘→’的方向为反向传播的方向。然而一般情况下当反向传播backward()结束时代表计算图的一次迭代就结束了,此时计算图会自动free掉。但在我们的实验过程中,常常需要设计复杂的损失函数以取得我们所需要的显著的实验效果。如下图:
在这里插入图片描述
pytorch的计算图loss.backward(retain_graph=True)实现复杂loss的训练_第2张图片
pytorch的计算图loss.backward(retain_graph=True)实现复杂loss的训练_第3张图片
pytorch的计算图loss.backward(retain_graph=True)实现复杂loss的训练_第4张图片
两个损失函数是截然不同的两类损失函数,因此我们可以通过代码:backward(retain_graph=True)在计算出第一个损失函数的梯度值后保存计算图用于继续计算第二个损失函数的梯度。
pytorch的计算图loss.backward(retain_graph=True)实现复杂loss的训练_第5张图片

你可能感兴趣的:(pytorch的计算图loss.backward(retain_graph=True)实现复杂loss的训练)