pytorch中使用torchviz可视化某网络或loss函数计算图后,计算图节点的理解

一、安装graphviz之后,添加环境变量,可以用torchviz输出网络结构计算图

from torchviz import make_dot

make_dot(loss).view()

二、.backward() 方法

在 PyTorch 中,当在计算图上定义一系列操作后调用 .backward() 方法时,PyTorch 会为每个操作生成一个名为 Backward 的节点。这些名称是 PyTorch 为了跟踪反向传播的操作而生成的,是 PyTorch 内部生成的。

具体而言,名称的格式为 [OperatorName]Backward[Number],其中 OperatorName 表示操作的名称,Number 表示该操作在图中的位置。例如,SubBackward0 表示在图中第一个计算的减法操作;MmBackward0 表示第一个计算的乘法操作;UnsqueezeBackward0 表示第一个计算的 unsqueeze 操作。

除了上面的操作名称外,还有其他的节点类型。这取决于在图中定义的操作类型。对于每种操作,PyTorch 都会生成对应的反向传播节点,以支持自动求导。

三、计算图常见的节点类型名称:

  1. AddBackward:表示加法操作。

  1. SubBackward:表示减法操作。

  1. MulBackward:表示乘法操作。

  1. DivBackward:表示除法操作。

  1. ExpBackward:表示指数运算。

  1. LogBackward:表示对数运算。

  1. MatmulBackward:表示矩阵乘法。

  1. UnsqueezeBackward:表示添加一维的操作。

  1. SqueezeBackward:表示删除一维的操作。

  1. Conv2dBackward:表示卷积操作。

  1. MaxPool2dBackward:表示最大池化操作。

  1. StackBackward 表示堆叠操作,即将一组张量堆叠在一起,形成一个新的张量。

  1. TBackward 表示转置操作,即对一个张量进行转置。

  1. SelectBackward 表示选择操作,即选择张量的一部分作为新的张量。在 PyTorch 中,这可以通过索引(indexing)实现。

下图是自己定义简单MLP的torchviz可视化计算图(截取)

pytorch中使用torchviz可视化某网络或loss函数计算图后,计算图节点的理解_第1张图片

你可能感兴趣的:(pytorch学习积累,python,pytorch)