PyTorch CAM Feature Visualization

Background

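Class Activation Mapping (CAM) is a technique for visualizing what a deep network has learned. It uses feature responses to localize the image regions that drive a prediction, which makes it one approach to deep-learning interpretability. CAM displays the strength of local responses as a heatmap. Locations with stronger responses are more discriminative for the predicted class.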

Paper: Learning Deep Features for Discriminative Localization

Basic principle of CAM:

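Define the score of class c as $S_c = \sum_k w_k^c \sum_{x,y} f_k(x,y) = \sum_{x,y} \sum_k w_k^c f_k(x,y)$. Here $f_k(x,y)$ is the activation of channel k of the last convolutional layer at spatial location (x, y), and $w_k^c$ is the weight connecting channel k to class c. The CAM for class c is defined as $M_c(x,y) = \sum_k w_k^c f_k(x,y)$. It follows that $S_c = \sum_{x,y} M_c(x,y)$, so $M_c(x,y)$ directly measures how much location (x, y) contributes to the score of class c.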
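To make the definition concrete, here is a minimal NumPy sketch. Random arrays stand in for the last conv layer's output f (shape K×H×W) and the classifier weights w (shape C×K), and `class_activation_map` is a hypothetical helper name. The sketch checks the identity $S_c = \sum_{x,y} M_c(x,y)$ using the paper's simplified score, which sums over locations and has no bias. In a real network with global average pooling, the logit is $S_c/(HW)$ plus a bias. That changes the score's value but not where $M_c$ peaks.

```python
import numpy as np


def class_activation_map(features: np.ndarray, fc_weight: np.ndarray, c: int) -> np.ndarray:
    """M_c(x, y) = sum_k w_k^c f_k(x, y).

    features:  last conv-layer output, shape (K, H, W)
    fc_weight: classifier weights, shape (C, K); row c holds w_k^c
    """
    # Contract over the channel axis k, keeping the spatial axes (x, y).
    return np.einsum("k,khw->hw", fc_weight[c], features)


rng = np.random.default_rng(0)
K, H, W, C = 8, 7, 7, 5
f = rng.standard_normal((K, H, W))   # stand-in for f_k(x, y)
w = rng.standard_normal((C, K))      # stand-in for w_k^c

c = 3
M_c = class_activation_map(f, w, c)
# Paper's score: S_c = sum_k w_k^c * sum_{x,y} f_k(x, y)  (sum pooling, no bias)
S_c = np.sum(w[c] * f.sum(axis=(1, 2)))

print(M_c.shape)                   # (7, 7)
print(np.isclose(M_c.sum(), S_c))  # True
```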

Related methods: Grad-CAM: https://arxiv.org/pdf/1610.02391.pdf, Grad-CAM++: https://arxiv.org/pdf/1710.11063.pdf
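Grad-CAM replaces the fc weights $w_k^c$ with the spatial average of $\partial S_c / \partial f_k(x,y)$, so it also works for architectures without a GAP + fc head. The sketch below is a minimal one: random tensors stand in for a real network's last conv output and fc layer, and all names are illustrative. It computes the Grad-CAM weights with `torch.autograd.grad` on a global-average-pooling + linear head. It then checks that the weights reduce to $w_k^c/(HW)$, which means Grad-CAM recovers CAM up to a constant factor for such models.

```python
import torch

torch.manual_seed(0)
K, H, W, C = 16, 7, 7, 10
# Hypothetical stand-ins for the last conv output and the fc layer of a ResNet-style model.
features = torch.randn(1, K, H, W, requires_grad=True)
fc = torch.nn.Linear(K, C)

logits = fc(features.mean(dim=(2, 3)))   # global average pooling + fc
c = int(logits.argmax())                 # predicted class
grads, = torch.autograd.grad(logits[0, c], features)

# Grad-CAM: alpha_k^c = spatial mean of dS_c/df_k, map = ReLU(sum_k alpha_k^c f_k)
alpha = grads.mean(dim=(2, 3))[0]                          # shape (K,)
f0 = features[0].detach()
grad_cam = torch.relu(torch.einsum("k,khw->hw", alpha, f0))

# CAM computed directly from the fc weights of class c
cam = torch.relu(torch.einsum("k,khw->hw", fc.weight[c].detach(), f0))

print(torch.allclose(alpha, fc.weight[c].detach() / (H * W)))  # True
print(torch.allclose(grad_cam * (H * W), cam, atol=1e-5))      # True
```

For models where the last conv layer is followed by more than one layer, the gradients are no longer constant over (x, y). In that case Grad-CAM and CAM differ, and only Grad-CAM remains applicable.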

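Feature visualization code based on ResNet50. It weights each channel of the last convolutional layer by the spatial average of its gradient (the Grad-CAM recipe). ResNet50 ends in global average pooling followed by a single fc layer, so that average equals $w_k^c/(HW)$, with H = W = 7 for a 224×224 input. The resulting map is therefore proportional to the CAM $M_c$ defined above: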

```python
import os
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2

# Work around the "OMP: Error #15 ... libiomp5 already initialized" crash that can occur
# when several libraries ship their own OpenMP runtime (common with conda on Windows).
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

def draw_cam(model, img_path, save_path, transform=None, visheatmap=False):
    img = Image.open(img_path).convert('RGB')
    if transform is None:
        transform = transforms.ToTensor()      # the model needs a tensor input
    img = transform(img).unsqueeze(0)          # 1x3xHxW
    model.eval()
    # Run the ResNet50 backbone step by step so the last conv feature map is accessible.
    x = model.conv1(img)
    x = model.bn1(x)
    x = model.relu(x)
    x = model.maxpool(x)
    x = model.layer1(x)
    x = model.layer2(x)
    x = model.layer3(x)
    x = model.layer4(x)
    features = x                               # 1x2048x7x7
    output = model.avgpool(x)                  # 1x2048x1x1
    output = torch.flatten(output, 1)          # 1x2048
    output = model.fc(output)                  # 1x1000

    # Capture d(score)/d(features) with a hook instead of a global variable.
    grads = {}
    def extract(g):
        grads['features'] = g
    features.register_hook(extract)

    pred = torch.argmax(output, dim=1).item()  # predicted class index
    pred_class = output[:, pred]
    pred_class.backward()

    # Weight each channel by its spatially averaged gradient, then average over channels.
    pooled_grads = grads['features'].mean(dim=(2, 3))[0]          # 2048
    with torch.no_grad():
        weighted = features[0] * pooled_grads[:, None, None]      # 2048x7x7
        heatmap = weighted.mean(dim=0).numpy()                    # 7x7
    heatmap = np.maximum(heatmap, 0)           # ReLU: keep positive evidence only
    heatmap /= np.max(heatmap) + 1e-8          # normalize to [0, 1], avoid division by zero

    if visheatmap:
        plt.matshow(heatmap)
        # plt.savefig('./heatmap.png')
        plt.show()

    img = cv2.imread(img_path)                 # BGR, HxWx3, uint8
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # Blend, then clip back to the valid 8-bit range before saving.
    superimposed_img = np.clip(heatmap * 0.4 + img, 0, 255).astype(np.uint8)
    cv2.imwrite(save_path, superimposed_img)

if __name__ == '__main__':
    # torchvision >= 0.13; older versions use models.resnet50(pretrained=True)
    model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    draw_cam(model, './1.jpg', './cam_1.png', transform=transform, visheatmap=True)
```
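For comparison, the original CAM formulation needs no backward pass. When the last conv block feeds global average pooling and a single fc layer, $M_c$ can be read directly off the fc weights. Torchvision ResNets have this layout and expose the two parts as `layer4` and `fc`. The sketch below captures the `layer4` output with a forward hook (`register_forward_hook`). The helper `resnet_cam` and the tiny `ToyNet` model are hypothetical, written only for illustration, and they assume that `layer4`/`fc` layout.

```python
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def resnet_cam(model: nn.Module, x: torch.Tensor,
               class_idx: Optional[int] = None) -> Tuple[torch.Tensor, int]:
    """Gradient-free CAM for a model laid out as layer4 -> global avg pool -> fc.

    Hypothetical helper: returns an HxW map in [0, 1] at the input resolution
    and the class index it was computed for.
    """
    store = {}

    def save_output(module, inputs, output):
        store["feat"] = output                   # last conv feature map, 1xKxhxw

    handle = model.layer4.register_forward_hook(save_output)
    model.eval()
    try:
        with torch.no_grad():
            logits = model(x)                    # 1 x num_classes
    finally:
        handle.remove()

    c = int(logits.argmax(dim=1)) if class_idx is None else class_idx
    with torch.no_grad():
        # M_c(x, y) = sum_k w_k^c f_k(x, y)
        cam = torch.einsum("k,khw->hw", model.fc.weight[c], store["feat"][0])
        cam = torch.relu(cam)                    # keep positive evidence only
        cam = cam / (cam.max() + 1e-8)           # scale to [0, 1]
        # Bilinear upsampling to the input size so the map can be overlaid.
        cam = F.interpolate(cam[None, None], size=x.shape[-2:],
                            mode="bilinear", align_corners=False)[0, 0]
    return cam, c


class ToyNet(nn.Module):
    """Hypothetical tiny model with the same layer4/avgpool/fc layout as a ResNet."""

    def __init__(self, num_classes: int = 5):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
        self.layer4 = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):
        x = self.layer4(torch.relu(self.stem(x)))
        return self.fc(torch.flatten(self.avgpool(x), 1))


torch.manual_seed(0)
model = ToyNet()
x = torch.randn(1, 3, 32, 32)
cam, c = resnet_cam(model, x)
print(cam.shape, c)   # cam has shape (32, 32)
```

With a torchvision ResNet50, the same call would take the 1x3x224x224 tensor produced by the transform in the `__main__` block. The returned map can then be colorized and blended the same way `draw_cam` does. Skipping the backward pass makes this cheaper, but it only applies to the GAP + single-fc layout.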

Results:

[Figures 1 and 2: PyTorch CAM feature visualization examples]

Project repository: sourceCode
