基于pytorch的特征图可视化

前言

在利用深度学习进行分类时,有时需要对中间的特征图进行可视化操作,看看网络都学习了哪些东西。本篇博文将简单介绍下,可视化操作。

网络模型部分

主要是forward部分,简单处理下。

def forward(self, x):
    outputs = []
    conv0 = self.encoder.conv1(x)
    conv0 = self.encoder.bn1(conv0)
    conv0 = self.encoder.relu(conv0)

    conv1 = self.pool(conv0)
    conv1 = self.conv1(conv1)
    outputs.append(conv1)

    conv2 = self.conv2(conv1)
    outputs.append(conv2)

    conv3 = self.conv3(conv2)
    outputs.append(conv3)
    return outputs

可视化操作

# load image
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
img = cv.imread("./data/JAX_Tile_019__41.png")
tensor = img_to_tensor(img)
tensor = Variable(torch.unsqueeze(tensor, dim=0).float(), requires_grad=False)

# forward
model.eval()
out_put = model(tensor.to(device))
for feature_map in out_put:
    # [N, C, H, W] -> [C, H, W]
    im = np.squeeze(feature_map.detach().cpu().numpy())
    # [C, H, W] -> [H, W, C]
    im = np.transpose(im, [1, 2, 0])

    # show top 12 feature maps
    plt.figure()
    for i in range(16):
        ax = plt.subplot(4, 4, i+1)
        # [H, W, C]

        cmap = 'nipy_spectral'
        plt.imshow(im[:, :, i], cmap=plt.get_cmap(cmap))
    plt.show()

原图
基于pytorch的特征图可视化_第1张图片
第一层特征图可视化:
基于pytorch的特征图可视化_第2张图片
第二层:
基于pytorch的特征图可视化_第3张图片
第三层:
基于pytorch的特征图可视化_第4张图片

你可能感兴趣的:(pytorch)