【深度学习】利用python画注意力热点图(heatmap)

目录

  • 概述
  • requirements
  • 代码实现
    • 1. 计算梯度图
    • 2. 计算热图
    • 3. 批量保存
  • 源代码
  • 参考

概述

跑完计算机视觉模型后,想要看特征图(注意力图)的激活区域对应原图的位置,则需要将激活度大于阈值的特征点标识出来,并将其上采样到原图大小,和原图按一定比例覆盖,可视化效果如下:
【深度学习】利用python画注意力热点图(heatmap)_第1张图片

requirements

  1. 环境:
  • torch >= 1.6
  • torchvision >= 0.8
  • numpy
  • opencv
  • timm
  • 若版本不对,可参考笔者使用torch = 1.8,torchvison = 0.9
  1. 配置文件

.yaml 格式的config文件
包含内容如下,【深度学习】利用python画注意力热点图(heatmap)_第2张图片

  1. 权重

想要验证的 .pth 模型权重,需要和选用的模型匹配

代码实现

1. 计算梯度图

def simple_grad_cam(features, classifier, target_class):
	# 得到特征图
    features = nn.Parameter(features) 
    # 计算分类结果
    logits = torch.matmul(features, classifier)
    # 回传
    logits[0, :, :, target_class].sum().backward()
    # 计算梯度
    features_grad = features.grad[0].sum(0).sum(0).unsqueeze(0).unsqueeze(0)
    # relu后归一化
    gramcam = F.relu(features_grad * features[0])
    gramcam = gramcam.sum(-1)
    gramcam = (gramcam - torch.min(gramcam)) / (torch.max(gramcam) - torch.min(gramcam))

    return gramcam

2. 计算热图

def get_heat(model, img):
    # 只需要前向传播过程
    with torch.no_grad():
        outs = model.forward_backbone(img.unsqueeze(0))
    
    # 得到特征
    features = []
    for name in outs:
        features.append(outs[name][0])
	
	# 每一层的权重
    layer_weights = [8, 4, 2, 1]
    # 初始化热图
    heatmap = np.zeros([args.data_size, args.data_size, 3])
    # 预处理每一张特征图
    for i in range(len(features)):
        f = features[i]
        f = f.cpu()
        if len(f.size()) == 2:
            S = int(f.size(0) ** 0.5)
            f = f.view(S, S, -1)
        # 调用simple_grad_cam计算梯度图
        gramcam = simple_grad_cam(f.unsqueeze(0), classifier=torch.ones(f.size(-1), 200)/f.size(-1), target_class=args.target_class)
        gramcam = gramcam.detach().numpy()
        # resize到原图大小
        gramcam = cv2.resize(gramcam, (args.data_size, args.data_size))
        # 热图颜色默认为红色
        heatmap[:, :, 2] += layer_weights[i] * gramcam
	# 按权重分配激活度
    heatmap = heatmap / sum(layer_weights)
    # 归一化
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
    heatmap[heatmap < args.threshold] = 0 # threshold
    # 从张量图变为RGB
    heatmap *= 255
    heatmap = heatmap.astype(np.uint8)

    return heatmap

3. 批量保存

if args.save_img_path != "":
	file_n = "mix" + os.path.basename(img_p)
	cv2.imwrite(os.path.join(args.save_img_path,file_n) , mix)

源代码

源代码

参考

https://github.com/chou141253/FGVC-PIM/

你可能感兴趣的:(CV&DL,python,深度学习)