GradCAM: Visualizing Neural Network Attention

GradCAM (GradCAM++)

Open-source tool: https://pypi.org/project/pytorch-gradcam/

Papers: https://arxiv.org/pdf/1610.02391 (Grad-CAM), https://arxiv.org/pdf/1710.11063 (Grad-CAM++)

GitHub: https://github.com/vickyliin/gradcam_plus_plus-pytorch

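GradCAM is a visual analysis tool for image classification networks. It backpropagates a chosen output (a class score) to a chosen network layer and turns the result into a heatmap of the regions the network relied on to produce that output. This lets you check whether the network is attending to sensible parts of the image.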
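For reference, the Grad-CAM paper (arXiv 1610.02391) defines the map as follows. Here y^c is the score for class c before the softmax, A^k is the k-th feature map of the chosen layer, and Z is the number of spatial positions (i, j) in that map:

```latex
\alpha_k^c = \frac{1}{Z}\sum_{i}\sum_{j}\frac{\partial y^c}{\partial A_{ij}^k},
\qquad
L_{\text{Grad-CAM}}^c = \mathrm{ReLU}\!\left(\sum_k \alpha_k^c A^k\right)
```

The resulting map has the spatial size of the chosen layer, not of the image. For example, it is 7×7 for ResNet `layer4` with a 224×224 input, so it must be upsampled to the image size before it is overlaid.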

Demo program

import torch
import torchvision
from gradcam import GradCAM, GradCAMpp
from gradcam.utils import visualize_cam
from torchvision import transforms
import PIL.Image
import matplotlib.pyplot as plt


def main():
    resnet = torchvision.models.resnet101(pretrained=True)
    resnet.eval()
    gradcam = GradCAM.from_config(model_type='resnet', arch=resnet, layer_name='layer4')

    # get an image and normalize with mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
    pil_img = PIL.Image.open('snake.JPEG').convert('RGB')  # force 3 channels (grayscale/RGBA inputs would break Normalize)
    torch_img = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])(pil_img).to('cpu')
    normed_img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(torch_img)[None]

    # get a GradCAM saliency map; class_idx=None uses the top-scoring (predicted) class
    mask, logit = gradcam(normed_img, class_idx=None)

    # make heatmap from mask and synthesize saliency map using heatmap and img
    heatmap, cam_result = visualize_cam(mask, torch_img)

    plt.figure()
    plt.imshow(transforms.ToPILImage()(heatmap))

    plt.figure()
    plt.imshow(transforms.ToPILImage()(cam_result))

    plt.show()
    


if __name__ == '__main__':
    main()

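GradCAM selects the layer to analyze by module name (here `layer4`, passed as `layer_name`). It selects the class whose attention is analyzed through `class_idx`. If `class_idx` is None, the class with the highest predicted score for the current image is used, i.e. the network's final classification result.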
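The `class_idx` rule can be mirrored in a few lines of NumPy. `select_score` and `top_k_classes` are hypothetical helpers, not part of the gradcam package. They follow the logic in `GradCAM.forward` below, where the backpropagated score is the raw logit (before softmax). The top-k helper is one way to find other class indices worth passing explicitly.

```python
import numpy as np


def select_score(logits, class_idx=None):
    """Pick the scalar to backpropagate from a (1, num_classes) logit array.

    Hypothetical helper mirroring GradCAM.forward: class_idx=None means
    "use the predicted (highest-scoring) class".
    """
    if class_idx is None:
        class_idx = int(np.argmax(logits[0]))
    return class_idx, float(logits[0, class_idx])


def top_k_classes(logits, k):
    """Indices of the k highest logits, best first (candidates for class_idx)."""
    return [int(i) for i in np.argsort(logits[0])[::-1][:k]]


logits = np.array([[0.1, 2.5, -1.0, 0.7]])
print(select_score(logits))      # (1, 2.5)
print(select_score(logits, 3))   # (3, 0.7)
print(top_k_classes(logits, 2))  # [1, 3]
```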

Core code

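import torch
import torch.nn.functional as F

# layer_finders (model_type -> function that locates a layer by name) is provided
# by the gradcam package and is not reproduced here.


class GradCAM(object):
    """Calculate GradCAM saliency map.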

    A simple example:

        # initialize a model, model_dict and gradcam
        resnet = torchvision.models.resnet101(pretrained=True)
        resnet.eval()
        gradcam = GradCAM.from_config(model_type='resnet', arch=resnet, layer_name='layer4')

        # get an image and normalize with mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
        img = load_img()
        normed_img = normalizer(img)

        # get a GradCAM saliency map on the class index 10.
        mask, logit = gradcam(normed_img, class_idx=10)

        # make heatmap from mask and synthesize saliency map using heatmap and img
        heatmap, cam_result = visualize_cam(mask, img)
    """

    def __init__(self, arch: torch.nn.Module, target_layer: torch.nn.Module):
        self.model_arch = arch

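        self.gradients = dict()
        self.activations = dict()

        def save_gradient(grad):
            # called during backward with d(score)/d(target layer output)
            self.gradients['value'] = grad

        def forward_hook(module, input, output):
            # store the target layer's output (the feature maps)
            self.activations['value'] = output
            # attach a tensor hook so the gradient w.r.t. this output is captured in backward;
            # skipped when autograd is off (e.g. under torch.no_grad())
            if output.requires_grad:
                output.register_hook(save_gradient)

        target_layer.register_forward_hook(forward_hook)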

    @classmethod
    def from_config(cls, arch: torch.nn.Module, model_type: str, layer_name: str):
        target_layer = layer_finders[model_type](arch, layer_name)
        return cls(arch, target_layer)

    def saliency_map_size(self, *input_size):
        device = next(self.model_arch.parameters()).device
        self.model_arch(torch.zeros(1, 3, *input_size, device=device))
        return self.activations['value'].shape[2:]

    def forward(self, input, class_idx=None, retain_graph=False):
        """
        Args:
            input: input image with shape of (1, 3, H, W)
            class_idx (int): class index for calculating GradCAM.
                    If not specified, the class index that makes the highest model prediction score will be used.
        Return:
            mask: saliency map of the same spatial dimension with input
            logit: model output
        """
        b, c, h, w = input.size()

        logit = self.model_arch(input)
        if class_idx is None:
            score = logit[:, logit.max(1)[-1]].squeeze()
        else:
            score = logit[:, class_idx].squeeze()

        self.model_arch.zero_grad()
        score.backward(retain_graph=retain_graph)
        gradients = self.gradients['value']
        activations = self.activations['value']
        b, k, u, v = gradients.size()

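        # channel weights: global average of the gradients over the u*v positions
        alpha = gradients.view(b, k, -1).mean(2)
        #alpha = F.relu(gradients.view(b, k, -1)).mean(2)
        weights = alpha.view(b, k, 1, 1)

        # weighted sum of the feature maps over channels, then ReLU
        saliency_map = (weights*activations).sum(1, keepdim=True)
        saliency_map = F.relu(saliency_map)
        # upsample to the input resolution
        saliency_map = F.interpolate(saliency_map, size=(h, w), mode='bilinear', align_corners=False)
        # min-max normalize to [0, 1]; the epsilon avoids division by zero for an all-zero map
        saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()
        saliency_map = (saliency_map - saliency_map_min).div(saliency_map_max - saliency_map_min + 1e-8).detach()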

        return saliency_map, logit

    def __call__(self, input, class_idx=None, retain_graph=False):
        return self.forward(input, class_idx, retain_graph)

Initialization

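Hooks are registered on the target layer to capture its forward output and, during backward, the gradient of the score with respect to that output.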
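Below is a minimal, self-contained sketch of the hook mechanism on a tiny hypothetical network. The model, its layer sizes and the input are made up for illustration. It combines a forward hook with a tensor hook on the layer output to capture the gradient. This is one common way to do it, not necessarily how every GradCAM implementation registers its hooks.

```python
import torch
import torch.nn as nn

# Tiny stand-in network (hypothetical): a conv "feature" layer plus a classifier head.
torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),  # index 0: the layer we inspect
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(4, 5),  # 5 classes
)
model.eval()
target_layer = model[0]

captured = {}


def forward_hook(module, inputs, output):
    # Save the layer's forward output (the feature maps A^k) ...
    captured["activations"] = output
    # ... and attach a tensor hook that fires during backward with d(score)/dA.
    if output.requires_grad:
        output.register_hook(lambda grad: captured.__setitem__("gradients", grad))


handle = target_layer.register_forward_hook(forward_hook)

x = torch.randn(1, 3, 8, 8)
logits = model(x)
class_idx = logits.argmax(dim=1).item()  # same rule as class_idx=None
model.zero_grad()
logits[0, class_idx].backward()
handle.remove()  # detach the forward hook once it is no longer needed

print(captured["activations"].shape)  # torch.Size([1, 4, 8, 8])
print(captured["gradients"].shape)    # torch.Size([1, 4, 8, 8])
```

The gradient has the same shape as the activations, which is what the weighting step relies on.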

The forward method

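  1. Run the input through the network to get the classification output; a hook captures the target layer's output.
  2. Backpropagate the score of the selected class; a hook captures the gradient at the target layer.
  3. Compute channel weights from the gradients (their spatial mean) and use them to weight the target layer's output, giving the attention map.
  4. Apply ReLU, upsample to the input size, and min-max normalize the attention map to obtain the final visualization.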
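Steps 3 and 4 can be sketched in NumPy on single-image arrays of shape (K, u, v). `gradcam_from_arrays` is a hypothetical helper. To stay short, it replaces the bilinear upsampling used in the code above with nearest-neighbour upsampling via `np.kron`, so its maps are blockier than the library's.

```python
import numpy as np


def gradcam_from_arrays(activations, gradients, out_hw):
    """Grad-CAM map from one image's activations/gradients, both shaped (K, u, v).

    Assumption: nearest-neighbour upsampling (np.kron) instead of bilinear,
    so out_hw must be an integer multiple of (u, v).
    """
    k, u, v = activations.shape
    h, w = out_hw
    if h % u or w % v:
        raise ValueError("out_hw must be an integer multiple of the feature-map size")
    # Step 3a: one weight per channel = spatial mean of its gradient.
    weights = gradients.reshape(k, -1).mean(axis=1)   # shape (K,)
    # Step 3b: weighted sum of the feature maps over channels.
    cam = np.tensordot(weights, activations, axes=1)  # shape (u, v)
    # Step 4a: ReLU keeps only regions that push the class score up.
    cam = np.maximum(cam, 0.0)
    # Step 4b: upsample to the input resolution.
    cam = np.kron(cam, np.ones((h // u, w // v)))
    # Step 4c: min-max normalize to [0, 1]; epsilon avoids division by zero.
    return (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)


A = np.array([[[1.0, 0.0], [0.0, 0.0]],        # channel 0 fires top-left
              [[0.0, 0.0], [0.0, 2.0]]])       # channel 1 fires bottom-right
G = np.array([[[1.0, 1.0], [1.0, 1.0]],        # channel 0: mean gradient +1
              [[-1.0, -1.0], [-1.0, -1.0]]])   # channel 1: mean gradient -1
cam = gradcam_from_arrays(A, G, (4, 4))
print(cam.round(3))
```

Channel 1 gets a negative weight, so its region becomes negative and is removed by the ReLU. Only the top-left quadrant, where channel 0 fires, remains lit.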

Trying it on detection networks

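GradCAM was designed for classification networks. Applying it directly to a detection network, by backpropagating one of the detector's outputs, does produce a map, but that map does not reflect the network's attention very well.

Practical use on detection models needs further investigation.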
