mmdetection特征可视化V2

mmdetection特征可视化V2

  • 前言
  • 一、特征图可视化
    • 1.新建feature_visualization.py文件
    • 2.使用方法


前言

在上一篇博客中介绍了特征图可视化,发现还可以对其简化,不用修改一大堆东西,直接在我们想要可视化的地方直接调用可视化函数即可,方便大家在debug的时候可以快速的看到自己想要看的特征图

一、特征图可视化

1.新建feature_visualization.py文件

该文件我自己建立在tools文件夹下面,自己也可以新建一个文件夹放进去,里面主要包含两个函数,跟上一篇博客中的基本一样:

import cv2
import mmcv
import numpy as np
import os
import torch
import matplotlib.pyplot as plt


def featuremap_2_heatmap(feature_map):
    assert isinstance(feature_map, torch.Tensor)
    feature_map = feature_map.detach()
    heatmap = feature_map[:,0,:,:]*0
    heatmaps = []
    for c in range(feature_map.shape[1]):
        heatmap+=feature_map[:,c,:,:]
    heatmap = heatmap.cpu().numpy()
    heatmap = np.mean(heatmap, axis=0)

    heatmap = np.maximum(heatmap, 0)
    heatmap /= np.max(heatmap)
    heatmaps.append(heatmap)

    return heatmaps

def draw_feature_map(features,save_dir = 'feature_map',name = None):
    i=0
    if isinstance(features,torch.Tensor):
        for heat_maps in features:
            heat_maps=heat_maps.unsqueeze(0)
            heatmaps = featuremap_2_heatmap(heat_maps)
            # 这里的h,w指的是你想要把特征图resize成多大的尺寸
            # heatmap = cv2.resize(heatmap, (h, w))  
            for heatmap in heatmaps:
                heatmap = np.uint8(255 * heatmap)
                # 下面这行将热力图转换为RGB格式 ,如果注释掉就是灰度图
                heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
                superimposed_img = heatmap
                plt.imshow(superimposed_img,cmap='gray')
                plt.show()
    else:
        for featuremap in features:
            heatmaps = featuremap_2_heatmap(featuremap)
            # heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))  # 将热力图的大小调整为与原始图像相同
            for heatmap in heatmaps:
                heatmap = np.uint8(255 * heatmap)  # 将热力图转换为RGB格式
                # heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
                # superimposed_img = heatmap * 0.5 + img*0.3
                superimposed_img = heatmap
                plt.imshow(superimposed_img,cmap='gray')
                plt.show()
                # 下面这些是对特征图进行保存,使用时取消注释
                # cv2.imshow("1",superimposed_img)
                # cv2.waitKey(0)
                # cv2.destroyAllWindows()
                # cv2.imwrite(os.path.join(save_dir,name +str(i)+'.png'), superimposed_img)
                # i=i+1

2.使用方法

将上述代码文件准备好后,后面的步骤就很简单了,直接在你想使用的地方直接调用函数即可,实例如下,比如我们用Faster_rcnn网络,就在two_stage.py文件里面,找到**extract_feat()**函数,增加两行代码,如下所示:

def extract_feat(self, img):
        """Directly extract features from the backbone+neck."""
        x = self.backbone(img)
        # 可视化resnet产生的特征
        from tools.feature_visualization import draw_feature_map
        draw_feature_map(x)
        if self.with_neck:
            x = self.neck(x)
            # 可视化FPN产生的特征
            from tools.feature_visualization import draw_feature_map
        	draw_feature_map(x)
        return x

用起来还是非常简单的,假如你用了其他的网络检测模型,需要在mmdet/models/detectors下面的文件中找到你所用的detector,这个在model的config文件看你model的type就可以查到

你可能感兴趣的:(可视化,python)