GradCAM (GradCAM++)
Open-source tool: https://pypi.org/project/pytorch-gradcam/
Papers: https://arxiv.org/pdf/1610.02391 https://arxiv.org/pdf/1710.11063
GitHub: https://github.com/vickyliin/gradcam_plus_plus-pytorch
GradCAM is a visualization tool for analyzing image classification networks. It backpropagates a chosen output score to a chosen layer and uses the resulting gradients to build an attention heatmap, showing where the network was looking when it produced that prediction so you can check whether its focus is reasonable.
import torchvision
from gradcam import GradCAM, GradCAMpp
from gradcam.utils import visualize_cam
from torchvision import transforms
import PIL
import matplotlib.pyplot as plt


def main():
    resnet = torchvision.models.resnet101(pretrained=True)
    resnet.eval()
    gradcam = GradCAM.from_config(model_type='resnet', arch=resnet, layer_name='layer4')

    # load an image and normalize with mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
    pil_img = PIL.Image.open('snake.JPEG')
    torch_img = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])(pil_img).to('cpu')
    normed_img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(torch_img)[None]

    # get a GradCAM saliency map; with class_idx=None the top predicted class is used
    mask, logit = gradcam(normed_img, class_idx=None)

    # make a heatmap from the mask and blend it with the original image
    heatmap, cam_result = visualize_cam(mask, torch_img)

    plt.figure()
    plt.imshow(transforms.ToPILImage()(heatmap))
    plt.figure()
    plt.imshow(transforms.ToPILImage()(cam_result))
    plt.show()


if __name__ == '__main__':
    main()
GradCAM selects the layer whose attention is analyzed by passing the name of a network module (layer4 here), and class_idx selects which class's attention to analyze; if None is passed, it automatically uses the class that the image was finally predicted as.
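For reference, a minimal sketch of targeting an explicit class index and of the GradCAM++ variant. It assumes the same resnet, normed_img and torch_img as in the example above, and assumes GradCAMpp exposes the same from_config interface since it subclasses GradCAM; the class index 56 is just an illustrative choice.

# hypothetical continuation of main() above
mask, logit = gradcam(normed_img, class_idx=56)          # CAM for a specific class index
heatmap, cam_result = visualize_cam(mask, torch_img)

# GradCAM++ variant; assumes GradCAMpp.from_config mirrors GradCAM.from_config
gradcampp = GradCAMpp.from_config(model_type='resnet', arch=resnet, layer_name='layer4')
mask_pp, logit_pp = gradcampp(normed_img, class_idx=56)
heatmap_pp, cam_result_pp = visualize_cam(mask_pp, torch_img)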
import torch
import torch.nn.functional as F
# layer_finders (used in from_config) maps model_type strings to layer-lookup helpers
# and is provided elsewhere in the gradcam package.


class GradCAM(object):
    """Calculate GradCAM saliency map.

    A simple example:

        # initialize a model, model_dict and gradcam
        resnet = torchvision.models.resnet101(pretrained=True)
        resnet.eval()
        gradcam = GradCAM.from_config(model_type='resnet', arch=resnet, layer_name='layer4')

        # get an image and normalize with mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
        img = load_img()
        normed_img = normalizer(img)

        # get a GradCAM saliency map on the class index 10.
        mask, logit = gradcam(normed_img, class_idx=10)

        # make heatmap from mask and synthesize saliency map using heatmap and img
        heatmap, cam_result = visualize_cam(mask, img)
    """

    def __init__(self, arch: torch.nn.Module, target_layer: torch.nn.Module):
        self.model_arch = arch

        self.gradients = dict()
        self.activations = dict()

        def backward_hook(module, grad_input, grad_output):
            self.gradients['value'] = grad_output[0]

        def forward_hook(module, input, output):
            self.activations['value'] = output

        target_layer.register_forward_hook(forward_hook)
        target_layer.register_backward_hook(backward_hook)

    @classmethod
    def from_config(cls, arch: torch.nn.Module, model_type: str, layer_name: str):
        target_layer = layer_finders[model_type](arch, layer_name)
        return cls(arch, target_layer)

    def saliency_map_size(self, *input_size):
        device = next(self.model_arch.parameters()).device
        self.model_arch(torch.zeros(1, 3, *input_size, device=device))
        return self.activations['value'].shape[2:]

    def forward(self, input, class_idx=None, retain_graph=False):
        """
        Args:
            input: input image with shape of (1, 3, H, W)
            class_idx (int): class index for calculating GradCAM.
                If not specified, the class index that makes the highest model prediction score will be used.
        Return:
            mask: saliency map of the same spatial dimension with input
            logit: model output
        """
        b, c, h, w = input.size()

        logit = self.model_arch(input)
        if class_idx is None:
            score = logit[:, logit.max(1)[-1]].squeeze()
        else:
            score = logit[:, class_idx].squeeze()

        self.model_arch.zero_grad()
        score.backward(retain_graph=retain_graph)
        gradients = self.gradients['value']      # d(score)/d(activations), captured by the backward hook
        activations = self.activations['value']  # target-layer feature maps, captured by the forward hook
        b, k, u, v = gradients.size()

        alpha = gradients.view(b, k, -1).mean(2)  # global-average-pool the gradients to get channel weights
        # alpha = F.relu(gradients.view(b, k, -1)).mean(2)
        weights = alpha.view(b, k, 1, 1)

        saliency_map = (weights * activations).sum(1, keepdim=True)
        saliency_map = F.relu(saliency_map)
        saliency_map = F.upsample(saliency_map, size=(h, w), mode='bilinear', align_corners=False)
        saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()
        saliency_map = (saliency_map - saliency_map_min).div(saliency_map_max - saliency_map_min).data

        return saliency_map, logit

    def __call__(self, input, class_idx=None, retain_graph=False):
        return self.forward(input, class_idx, retain_graph)
By registering hooks on the target layer, GradCAM captures that layer's forward output (activations) and the gradient flowing back through it during backward.
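A minimal standalone sketch of this hook pattern follows. The toy model and layer choice are purely illustrative and not part of the gradcam package; register_full_backward_hook is used because register_backward_hook is deprecated in recent PyTorch versions.

import torch
import torch.nn as nn

# illustrative toy model; any nn.Module containing a conv layer would do
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
target_layer = model[0]

store = {}

def forward_hook(module, inp, out):
    store['activations'] = out          # feature maps produced in the forward pass

def backward_hook(module, grad_in, grad_out):
    store['gradients'] = grad_out[0]    # gradient of the score w.r.t. those feature maps

target_layer.register_forward_hook(forward_hook)
target_layer.register_full_backward_hook(backward_hook)

x = torch.randn(1, 3, 32, 32)
score = model(x)[0, 0]   # pick one output score to backprop from, as GradCAM does
score.backward()

print(store['activations'].shape, store['gradients'].shape)  # both torch.Size([1, 8, 32, 32])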
GradCAM was originally designed for classification networks. Transferring it directly to a detection network and backpropagating from one chosen output does produce a map, but the result does not reflect the network's attention well.
How to apply it effectively in that setting still needs further exploration.