Grad-CAM的关键思想是：先计算目标类别得分相对于某个卷积层输出特征图的梯度，并在空间维度（高、宽）上对梯度取全局平均，得到每个通道的权重；再用这些权重对该层的特征图加权求和，并经过ReLU只保留正向贡献，得到一个“粗糙”的热力图。这个热力图可以被上采样到输入图像的尺寸并叠加到原始图像上，以显示模型在预测该类别时最关注的区域。
Grad-CAM的优点是它适用于各种卷积神经网络，无需修改网络结构或重新训练。它为我们提供了一种简单而直观的方式，来理解模型针对特定输入做出的决策。
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
class GradCAM:
    """Grad-CAM (Selvaraju et al., ICCV 2017) for a single target layer.

    Hooks on ``target_layer`` capture its output feature maps (forward) and
    the gradient of the chosen class score w.r.t. that output (backward).
    ``generate_cam`` combines them into a class-activation heatmap.
    Call ``remove_hooks`` when finished so the hooks stop firing.
    """

    def __init__(self, model, target_layer):
        self.model = model                # network being explained
        self.target_layer = target_layer  # layer whose activations are visualised
        self.feature_maps = None          # target_layer output from the last forward pass
        self.gradients = None             # gradient w.r.t. that output from the last backward pass
        # NOTE(review): register_backward_hook is deprecated in favour of
        # register_full_backward_hook, but the full variant rejects in-place
        # ops (e.g. ReLU(inplace=True)) applied to the hooked output. Only
        # grad_output[0] is read here, which is what this hook reports.
        # Keep the handles so the hooks can be detached with remove_hooks().
        self._hook_handles = [
            target_layer.register_forward_hook(self.save_feature_maps),
            target_layer.register_backward_hook(self.save_gradients),
        ]

    def remove_hooks(self):
        """Detach the forward/backward hooks from ``target_layer``."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def save_feature_maps(self, module, input, output):
        """Forward hook: store the layer output, detached from the graph."""
        self.feature_maps = output.detach()

    def save_gradients(self, module, grad_input, grad_output):
        """Backward hook: store the gradient w.r.t. the layer output."""
        self.gradients = grad_output[0].detach()

    def generate_cam(self, image, class_idx=None):
        """Compute the Grad-CAM heatmap for one image.

        Args:
            image: Normalized input batch holding exactly one image, indexed
                as (1, C, H, W); ``image.size(2)``/``image.size(3)`` give the
                output height/width.
            class_idx: Class to explain. Defaults to the top-scoring class.

        Returns:
            ``(heatmap, superimposed_img)``: ``heatmap`` is an (H, W, 3) uint8
            RGB rendering of the CAM with the JET colormap; ``superimposed_img``
            is that heatmap blended (weight 0.4) onto the de-normalized input,
            clipped to uint8.

        Raises:
            ValueError: if ``image`` does not hold exactly one image.
            RuntimeError: if the hooks captured no activations/gradients
                (e.g. ``target_layer`` is not used in the forward pass or
                ``remove_hooks`` was already called).
        """
        if image.size(0) != 1:
            raise ValueError(
                f"generate_cam expects a batch of one image, got {image.size(0)}"
            )
        self.model.eval()
        # Clear stale captures so a hook that did not fire is detected below.
        self.feature_maps = None
        self.gradients = None

        output = self.model(image)
        if class_idx is None:
            class_idx = torch.argmax(output).item()

        self.model.zero_grad()
        # Back-propagate only the target class score. zeros_like keeps the
        # seed gradient on the output's device/dtype (works on CPU and GPU).
        one_hot = torch.zeros_like(output)
        one_hot[0, class_idx] = 1
        output.backward(gradient=one_hot)

        if self.feature_maps is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM hooks did not capture feature maps/gradients for target_layer"
            )

        # Channel weights: gradients averaged over batch and spatial dims.
        weights = self.gradients.mean(dim=(0, 2, 3))
        # Weight each channel out-of-place so self.feature_maps is untouched.
        weighted = self.feature_maps[0] * weights[:, None, None]
        # Mean over channels (a constant multiple of the paper's sum, cancelled
        # by the normalisation below), then ReLU. float() because cv2.resize
        # does not accept float16.
        cam = np.maximum(weighted.mean(dim=0).float().cpu().numpy(), 0)
        peak = cam.max()
        if peak > 0:  # an all-zero CAM would otherwise become NaN
            cam = cam / peak

        cam = cv2.resize(cam, (image.size(3), image.size(2)))
        heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
        # applyColorMap returns BGR; the overlay image and plt.imshow use RGB.
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        original_image = self.unprocess_image(image[0].detach().cpu().numpy())
        superimposed_img = np.clip(heatmap * 0.4 + original_image, 0, 255).astype(np.uint8)
        return heatmap, superimposed_img

    def unprocess_image(self, image):
        """Undo ToTensor + ImageNet Normalize on a (3, H, W) array.

        Returns an (H, W, 3) uint8 array. Values are clipped to [0, 255] so
        out-of-range inputs saturate instead of wrapping around in the cast.
        """
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        restored = image.transpose(1, 2, 0) * std + mean
        return (np.clip(restored, 0.0, 1.0) * 255).astype(np.uint8)
def visualize_gradcam(model, input_image_path, target_layer, device=None):
    """Run Grad-CAM on one image file and plot the heatmap beside the overlay.

    Args:
        model: CNN classifier whose prediction is explained.
        input_image_path: Path to an image file readable by PIL.
        target_layer: Module inside ``model`` whose output feature maps are
            used for Grad-CAM (e.g. the last convolutional block).
        device: Device to run on. Defaults to the device of the model's first
            parameter (CPU if the model has no parameters).
    """
    # ImageNet-style preprocessing (same mean/std as GradCAM.unprocess_image).
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # convert("RGB") makes grayscale/RGBA/palette files 3-channel so the
    # 3-channel Normalize works; the `with` block closes the file handle.
    with Image.open(input_image_path) as img:
        input_tensor = preprocess(img.convert("RGB")).unsqueeze(0)

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    input_tensor = input_tensor.to(device)

    gradcam = GradCAM(model, target_layer)
    try:
        heatmap, result = gradcam.generate_cam(input_tensor)
    finally:
        # Detach the layer hooks when the GradCAM object supports it, so
        # repeated calls do not stack hooks on target_layer.
        remove_hooks = getattr(gradcam, "remove_hooks", None)
        if callable(remove_hooks):
            remove_hooks()

    plt.figure(figsize=(10, 10))
    plt.subplot(1, 2, 1)
    plt.imshow(heatmap)
    plt.title('Heatmap')
    plt.axis('off')
    plt.subplot(1, 2, 2)
    plt.imshow(result)
    plt.title('Superimposed Image')
    plt.axis('off')
    plt.show()
# Load your model (e.g., resnet20 in this case)
# model = resnet20()
# model.load_state_dict(torch.load("path_to_your_weights.pth"))
# model.to('cuda')
# Visualize GradCAM
# visualize_gradcam(model, "path_to_your_input_image.jpg", model.layer3[-1])
import torch
import cv2
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
class GradCAM:
    """Grad-CAM (Selvaraju et al., ICCV 2017) for a single target layer.

    Hooks on ``target_layer`` capture its output feature maps (forward) and
    the gradient of the chosen class score w.r.t. that output (backward).
    ``generate_cam`` combines them into a class-activation heatmap.
    Call ``remove_hooks`` when finished so the hooks stop firing.
    """

    def __init__(self, model, target_layer):
        self.model = model                # model to run Grad-CAM on
        self.target_layer = target_layer  # target layer whose features are visualised
        self.feature_maps = None          # stored feature maps (layer output)
        self.gradients = None             # stored gradients w.r.t. the layer output
        # Hook the target layer to save its output and gradient.
        # NOTE(review): register_backward_hook is deprecated in favour of
        # register_full_backward_hook, but the full variant rejects in-place
        # ops (e.g. ReLU(inplace=True)) applied to the hooked output. Only
        # grad_output[0] is read here, which is what this hook reports.
        # Keep the handles so the hooks can be detached with remove_hooks().
        self._hook_handles = [
            target_layer.register_forward_hook(self.save_feature_maps),
            target_layer.register_backward_hook(self.save_gradients),
        ]

    def remove_hooks(self):
        """Detach the forward/backward hooks from ``target_layer``."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def save_feature_maps(self, module, input, output):
        """Forward hook: save the feature maps, detached from the graph."""
        self.feature_maps = output.detach()

    def save_gradients(self, module, grad_input, grad_output):
        """Backward hook: save the gradient w.r.t. the layer output."""
        self.gradients = grad_output[0].detach()

    def generate_cam(self, image, class_idx=None):
        """Generate the CAM heatmap for one image.

        Args:
            image: Normalized input batch holding exactly one image, indexed
                as (1, C, H, W); ``image.size(2)``/``image.size(3)`` give the
                output height/width.
            class_idx: Class to explain. Defaults to the top-scoring class.

        Returns:
            ``(heatmap, superimposed_img)``: ``heatmap`` is an (H, W, 3) uint8
            RGB rendering of the CAM with the JET colormap; ``superimposed_img``
            is that heatmap blended (weight 0.4) onto the de-normalized input,
            clipped to uint8.

        Raises:
            ValueError: if ``image`` does not hold exactly one image.
            RuntimeError: if the hooks captured no activations/gradients
                (e.g. ``target_layer`` is not used in the forward pass or
                ``remove_hooks`` was already called).
        """
        if image.size(0) != 1:
            raise ValueError(
                f"generate_cam expects a batch of one image, got {image.size(0)}"
            )
        # Put the model in evaluation mode.
        self.model.eval()
        # Clear stale captures so a hook that did not fire is detected below.
        self.feature_maps = None
        self.gradients = None

        # Forward pass.
        output = self.model(image)
        if class_idx is None:
            class_idx = torch.argmax(output).item()

        # Clear all gradients, then back-propagate only the target class
        # score. zeros_like keeps the seed on the output's device/dtype.
        self.model.zero_grad()
        one_hot = torch.zeros_like(output)
        one_hot[0, class_idx] = 1
        output.backward(gradient=one_hot)

        if self.feature_maps is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM hooks did not capture feature maps/gradients for target_layer"
            )

        # Channel weights: gradients averaged over batch and spatial dims.
        weights = self.gradients.mean(dim=(0, 2, 3))
        # Weight each channel out-of-place so self.feature_maps is untouched.
        weighted = self.feature_maps[0] * weights[:, None, None]
        # Build the heatmap: mean over channels (a constant multiple of the
        # paper's sum, cancelled by the normalisation below), then ReLU.
        # float() because cv2.resize does not accept float16.
        cam = np.maximum(weighted.mean(dim=0).float().cpu().numpy(), 0)
        peak = cam.max()
        if peak > 0:  # an all-zero CAM would otherwise become NaN
            cam = cam / peak

        cam = cv2.resize(cam, (image.size(3), image.size(2)))
        heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
        # applyColorMap returns BGR; the overlay image and plt.imshow use RGB.
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        # Superimpose the heatmap on the original image.
        original_image = self.unprocess_image(image[0].detach().cpu().numpy())
        superimposed_img = np.clip(heatmap * 0.4 + original_image, 0, 255).astype(np.uint8)
        return heatmap, superimposed_img

    def unprocess_image(self, image):
        """Undo ToTensor + ImageNet Normalize on a (3, H, W) array.

        Returns an (H, W, 3) uint8 array. Values are clipped to [0, 255] so
        out-of-range inputs saturate instead of wrapping around in the cast.
        """
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        restored = image.transpose(1, 2, 0) * std + mean
        return (np.clip(restored, 0.0, 1.0) * 255).astype(np.uint8)
def visualize_gradcam(model, input_image_path, target_layer, device=None):
    """Visualize the Grad-CAM heatmap for one image file.

    Args:
        model: CNN classifier whose prediction is explained.
        input_image_path: Path to an image file readable by PIL.
        target_layer: Module inside ``model`` whose output feature maps are
            used for Grad-CAM (e.g. the last convolutional block).
        device: Device to run on. Defaults to the device of the model's first
            parameter (CPU if the model has no parameters).
    """
    # ImageNet-style preprocessing (same mean/std as GradCAM.unprocess_image).
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Load the image. convert("RGB") makes grayscale/RGBA/palette files
    # 3-channel for Normalize; the `with` block closes the file handle.
    with Image.open(input_image_path) as img:
        input_tensor = preprocess(img.convert("RGB")).unsqueeze(0)

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    input_tensor = input_tensor.to(device)

    # Create the GradCAM helper and compute the heatmap.
    gradcam = GradCAM(model, target_layer)
    try:
        heatmap, result = gradcam.generate_cam(input_tensor)
    finally:
        # Detach the layer hooks when the GradCAM object supports it, so
        # repeated calls do not stack hooks on target_layer.
        remove_hooks = getattr(gradcam, "remove_hooks", None)
        if callable(remove_hooks):
            remove_hooks()

    # Show the heatmap and the overlay.
    # NOTE: the CJK titles below may render as empty boxes unless matplotlib
    # is configured with a font that has CJK glyphs.
    plt.figure(figsize=(10, 10))
    plt.subplot(1, 2, 1)
    plt.imshow(heatmap)
    plt.title('热力图')
    plt.axis('off')
    plt.subplot(1, 2, 2)
    plt.imshow(result)
    plt.title('叠加后的图像')
    plt.axis('off')
    plt.show()
# 以下是示例代码,显示如何使用上述代码。
# 首先,你需要加载你的模型和权重。
# model = resnet20()
# model.load_state_dict(torch.load("path_to_your_weights.pth"))
# model.to('cuda')
# 然后,调用`visualize_gradcam`函数来查看结果。
# visualize_gradcam(model, "path_to_your_input_image.jpg", model.layer3[-1])
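论文链接（Selvaraju et al., “Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization”, ICCV 2017）：https://openaccess.thecvf.com/content_ICCV_2017/papers/Selvaraju_Grad-CAM_Visual_Explanations_ICCV_2017_paper.pdf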