pytorch网络可视化-查看每层网络的特征图

pytorch网络可视化-查看每层网络的特征图

当今发论文离不开图像的可视化

于是乎我就业余时间打算学习一下如何可视化每一层的图像

一般来说会选取tensorboard来进行可视化

这里采用另外一种方式

  def forward(self, x):
        outputs = []
        x = self.conv1(x)
        outputs.append(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        outputs.append(x)
        x = self.layer2(x)
        outputs.append(x)
        x = self.layer3(x)
        outputs.append(x)
        x = self.layer4(x)
        outputs.append(x)

        return outputs

这是resnet的网络我们将每一层的输出添加进入outputs里面方便后面可视化

可视化关键代码


# forward
out_put = model(img)
#获取输出列表这是一个列表,里面每个代表了每层的输出
for feature_map in out_put:
    #通过一个迭代器来遍历每个特征图
    # [N, C, H, W] -> [C, H, W]
    im = np.squeeze(feature_map.detach().numpy())#把tensor变成numpy
    # [C, H, W] -> [H, W, C]
    im = np.transpose(im, [1, 2, 0])
	#对图像的通道进行处理
    # show top 12 feature maps
    plt.figure()
    for i in range(12):
        ax = plt.subplot(3, 4, i+1)
        # [H, W, C]
        plt.imshow(im[:, :, i], cmap='gray')
    plt.show()

原图pytorch网络可视化-查看每层网络的特征图_第1张图片
**

第一层

**
pytorch网络可视化-查看每层网络的特征图_第2张图片

第二层

pytorch网络可视化-查看每层网络的特征图_第3张图片

第三层

pytorch网络可视化-查看每层网络的特征图_第4张图片

第四层

pytorch网络可视化-查看每层网络的特征图_第5张图片

第五层

pytorch网络可视化-查看每层网络的特征图_第6张图片

你可能感兴趣的:(动手学深度学习笔记)