Visualizing Class Activation Heatmaps with Grad-CAM Using a Library (PyTorch Version)


References
https://blog.csdn.net/u014264373/article/details/116302678
https://github.com/jacobgil/pytorch-grad-cam

The code in this article was run on Colab.

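1. Install the CAM library (version 1.3.7) and import the required packages

!pip install grad-cam==1.3.7

from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM
from pytorch_grad_cam.utils.image import show_cam_on_image, deprocess_image, preprocess_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import cv2
import numpy as np
import torchvision
import torch
from PIL import Image
import matplotlib.pyplot as plt

# import os
# # On macOS, uncomment these lines if importing torch triggers the "OMP: Error #15" duplicate OpenMP runtime error
# os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"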
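The argument names used in this article (`targets`, `use_cuda`) depend on the installed release of the library; recent releases, for example, no longer accept `use_cuda`. A small standard-library check of the installed version can save some confusion. `installed_version` is a hypothetical helper name, not part of pytorch-grad-cam:

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Hypothetical helper: installed version of a pip distribution, or None if missing."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# The PyPI distribution is called "grad-cam"; its import name is pytorch_grad_cam
print("grad-cam:", installed_version("grad-cam"))
```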

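2. Load the model

# 1. Load a VGG16 pretrained on ImageNet
# (torchvision >= 0.13 prefers weights="IMAGENET1K_V1" over pretrained=True)
model = torchvision.models.vgg16(pretrained=True).eval()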

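3. Select the target layer

# 2. Select the target layer(s); GradCAM expects a list
target_layer = [model.features[-1]]  # vgg16
#target_layer = [model.layer4[-1]]   # resnet50
# Typical choices (from the pytorch-grad-cam README):
#   ResNet18 and ResNet50: model.layer4[-1]
#   VGG and densenet161:   model.features[-1]
#   mnasnet1_0:            model.layers[-1]
#   ViT:                   model.blocks[-1].norm1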
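For a model that is not in the list above, one rough heuristic is to hook the last convolutional layer. A minimal sketch, assuming that registration order in `named_modules()` matches execution order (true for plain `nn.Sequential` stacks such as VGG's `features`, not guaranteed in general); `find_last_module` is a hypothetical helper:

```python
import torch.nn as nn


def find_last_module(model: nn.Module, module_type=nn.Conv2d):
    """Hypothetical helper: return (name, module) of the last submodule of module_type."""
    found = None
    for name, module in model.named_modules():  # depth-first, in registration order
        if isinstance(module, module_type):
            found = (name, module)
    if found is None:
        raise ValueError(f"no {module_type.__name__} found in model")
    return found


# Toy CNN so the sketch runs without downloading weights
toy = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10),
)
name, layer = find_last_module(toy)
print(name, layer)  # 2 Conv2d(8, 16, ...)
```

With torchvision's VGG16 this returns `features.28`, the last `Conv2d`, rather than the final max-pool `features[-1]` suggested above. The max-pool output has half the spatial resolution, so its heatmap is coarser; either module can be passed as `target_layers=[layer]`.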

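4. Load the input image

# 3. Load the input image
img_path = '/content/drive/MyDrive/African elephant.jpeg'
img = Image.open(img_path).convert('RGB')
#img = img.resize((224,224))
# Step 1 (convert): turn the PIL image into an H x W x 3 uint8 NumPy array
img = np.array(img)
img_float_np = np.float32(img)/255  # scaled to [0, 1]; used later as the background for the heatmap
# Define the torchvision transforms: to a C x H x W tensor in [0, 1],
# then normalize with the ImageNet statistics the pretrained VGG16 was trained with
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

input_tensor = transform(img)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
input_tensor = input_tensor.to(device)

# Step 2 (expand): add a batch dimension so the tensor becomes a batch of one image
input_tensor = input_tensor.unsqueeze(0)

# For the 670x447 example image: input_tensor.shape == (1, 3, 447, 670)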

Figure 1: the input image, an arbitrary picture found online (670×447 pixels).
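`ToTensor` followed by `Normalize` and `unsqueeze(0)` amounts to a few array operations. The NumPy sketch below reproduces them so the shapes can be checked without loading a model. The mean/std values are the standard ImageNet statistics used by torchvision's pretrained classifiers, and `to_batch` is a hypothetical name:

```python
import numpy as np

# ImageNet per-channel statistics (RGB order)
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_batch(img_uint8: np.ndarray) -> np.ndarray:
    """H x W x 3 uint8 RGB image -> 1 x 3 x H x W float32 array, ImageNet-normalized."""
    x = img_uint8.astype(np.float32) / 255.0  # scale to [0, 1], as ToTensor does
    x = (x - IMAGENET_MEAN) / IMAGENET_STD    # per-channel normalization, broadcast over H and W
    x = np.transpose(x, (2, 0, 1))            # HWC -> CHW
    return x[np.newaxis, ...]                 # add the batch dimension


dummy = np.zeros((447, 670, 3), dtype=np.uint8)
batch = to_batch(dummy)
print(batch.shape)  # (1, 3, 447, 670)
```

`torch.from_numpy(to_batch(img))` should match a `ToTensor` + `Normalize` + `unsqueeze(0)` pipeline up to float rounding. `show_cam_on_image`, by contrast, needs the un-normalized [0, 1] image, which is why `img_float_np` is kept separately.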

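5. Initialize GradCAM

# 4. Initialize GradCAM with the model, the target layer(s) and whether to use CUDA
# (use_cuda=True fails on a machine without a GPU, so follow the actual hardware)
cam = GradCAM(model=model, target_layers=target_layer, use_cuda=torch.cuda.is_available())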

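6. Select the target class

# 5. Select the target class; if targets is None, the highest-scoring class is used
targets = None
# targets = [ClassifierOutputTarget(281)]  # explain ImageNet class 281 instead (one target per image in the batch)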
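A target in pytorch-grad-cam is a callable that turns the model's output for one image into the scalar score to explain; `ClassifierOutputTarget(281)` selects the logit of class 281. The sketch below mimics that idea, together with the `targets=None` default of picking each image's top class, using a hypothetical `PickClassLogit` class and made-up logits. It is an illustration, not the library's code:

```python
import torch


class PickClassLogit:
    """Hypothetical stand-in for ClassifierOutputTarget: returns one class's score."""

    def __init__(self, category: int):
        self.category = category

    def __call__(self, model_output: torch.Tensor) -> torch.Tensor:
        # model_output: logits of a single image, shape [num_classes]
        return model_output[self.category]


# Made-up logits for a batch of 2 images and 3 classes
logits = torch.tensor([[0.1, 2.5, -1.0],
                       [3.0, 0.2, 0.7]])

# Behaviour when targets=None: use each image's highest-scoring class
default_categories = logits.argmax(dim=1).tolist()
targets = [PickClassLogit(c) for c in default_categories]

# One scalar per image; Grad-CAM backpropagates from these scores
scores = torch.stack([t(out) for t, out in zip(targets, logits)])
print(default_categories)  # [1, 0]
print(scores.tolist())     # [2.5, 3.0]
```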

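7. Compute the CAM

# 6. Compute the CAM
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
# Optional arguments:
#   aug_smooth=True   test-time augmentation: combines horizontal flips with multiplying the image by
#                     [1.0, 1.1, 0.9], which tends to center the CAM on the object (about 6x slower)
#   eigen_smooth=True uses the first principal component of activations*weights, which removes much of the noise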
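Under the hood, Grad-CAM weights each channel A_k of the target layer by the spatial average of the gradient of the class score y_c, alpha_k = mean over (i, j) of dy_c / dA_k[i, j], and keeps the positive part of the weighted sum, ReLU(sum_k alpha_k * A_k), upsampled to the input size. The sketch below implements that recipe with PyTorch hooks on a small toy model so it runs without downloading weights. `grad_cam` is a hypothetical helper, and the library's own implementation differs in details:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def grad_cam(model: nn.Module, layer: nn.Module, x: torch.Tensor, category=None) -> torch.Tensor:
    """Minimal Grad-CAM sketch for one image x of shape [1, C, H, W]; returns [H, W] in [0, 1]."""
    store = {}

    def save_activation(module, inputs, output):
        store["act"] = output  # A_k: activations of the target layer

        def save_grad(grad):
            store["grad"] = grad  # d(score) / dA_k

        output.register_hook(save_grad)

    handle = layer.register_forward_hook(save_activation)
    try:
        model.eval()
        logits = model(x)  # [1, num_classes]
        if category is None:
            category = int(logits.argmax(dim=1))  # default: top-scoring class
        model.zero_grad()
        logits[0, category].backward()
    finally:
        handle.remove()

    act = store["act"].detach()
    grad = store["grad"].detach()
    weights = grad.mean(dim=(2, 3), keepdim=True)           # alpha_k: spatial mean of gradients
    cam = F.relu((weights * act).sum(dim=1, keepdim=True))  # ReLU(sum_k alpha_k * A_k), [1, 1, h, w]
    cam = F.interpolate(cam, size=x.shape[-2:], mode="bilinear", align_corners=False)[0, 0]
    cam = cam - cam.min()
    return cam / (cam.max() + 1e-8)


# Toy CNN so the sketch runs without downloading weights
torch.manual_seed(0)
toy = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 5),
)
x = torch.rand(1, 3, 32, 48)
heatmap = grad_cam(toy, toy[2], x)
print(heatmap.shape)  # torch.Size([32, 48])
```

With the model and tensor from the steps above (on the same device), `grad_cam(model, model.features[-1], input_tensor)` should give a heatmap comparable to `grayscale_cam[0]`, although small numerical differences from the library are to be expected.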

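8. Plot the result

# 7. Show the heatmap and save it. grayscale_cam holds one map per image in the batch; take the first
grayscale_cam = grayscale_cam[0, :]
cam_image = show_cam_on_image(img_float_np, grayscale_cam, use_rgb=True)
# cam_image is RGB but OpenCV writes BGR, so convert before saving to keep the colours correct
cv2.imwrite('/content/African elephant.jpeg', cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR))
plt.imshow(cam_image)
plt.axis('off')
plt.show()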

The visualization result:
Figure 2: the Grad-CAM heatmap overlaid on the input image.
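`show_cam_on_image` colour-maps the CAM and blends it with the [0, 1] float image. A rough sketch of that blend, assuming matplotlib's `jet` colormap as a stand-in for OpenCV's `COLORMAP_JET` (close but not pixel-identical) and a plain additive blend followed by rescaling. `overlay_cam` is a hypothetical name, and the library's exact weighting may differ between versions:

```python
import numpy as np
import matplotlib.pyplot as plt


def overlay_cam(img: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Blend a heatmap onto an image.

    img:  H x W x 3 float RGB image with values in [0, 1]
    mask: H x W CAM with values in [0, 1]
    Returns an H x W x 3 uint8 RGB image.
    """
    if img.max() > 1:
        raise ValueError("img must be a float image scaled to [0, 1]")
    heatmap = plt.get_cmap("jet")(mask)[..., :3]  # colormap returns RGBA floats; keep RGB
    cam = heatmap + img                            # additive blend
    cam = cam / cam.max()                          # rescale into [0, 1]
    return np.uint8(255 * cam)


img = np.full((4, 6, 3), 0.5, dtype=np.float32)
mask = np.linspace(0.0, 1.0, 24).reshape(4, 6)
rgb = overlay_cam(img, mask)
bgr = np.ascontiguousarray(rgb[..., ::-1])  # reversed channels: the order cv2.imwrite expects
print(rgb.shape, rgb.dtype)  # (4, 6, 3) uint8
```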
