可视化学习笔记11-pytorch-GradCAM可视化自己的网络

续可视化笔记2-pytorch 可视化卷积网络中间特征层的基础上使用CAM方法可视化网络对待测对象关注的位置。

1.定义GradCAM类

注意:代码中需要改的3个地方已经用注释标清,大家使用时注意修改。

class GradCAM(nn.Module):
    def __init__(self):
        super(GradCAM, self).__init__()
        # 获取模型的特征提取层
        self.feature = nn.Sequential(OrderedDict({
            name: layer for name, layer in model.named_children()
            if name not in ['avgpool', 'classifier']#改1:根据自己的网络模型架构调整。
        }))
        # 获取模型最后的平均池化层
        self.avgpool = model.avgpool
        # 获取模型的输出层
        self.classifier = nn.Sequential(OrderedDict([
            ('classifier', model.classifier)#改2:模型剩什么层就写什么层(这里我的网络除了avgpool就只剩classifier)。
        ]))
        # 生成梯度占位符
        self.gradients = None

    # 获取梯度的钩子函数
    def activations_hook(self, grad):
        self.gradients = grad

    def forward(self, x):
        x = self.feature(x)
        # 注册钩子
        h = x.register_hook(self.activations_hook)
        # 对卷积后的输出使用平均池化
        x = self.avgpool(x)
        x = x.view((1, -1))
        x = self.classifier(x)#改3:同2
        return x

    # 获取梯度的方法
    def get_activations_gradient(self):
        return self.gradients

    # 获取卷积层输出的方法
    def get_activations(self, x):
        return self.feature(x)

2.获取热力图

# 获取热力图
def get_heatmap(model, img):
    model.eval()
    img_pre = model(img)
    # 获取预测最高的类别
    pre_class = torch.argmax(img_pre, dim=-1).item()
    # 获取相对于模型参数的输出梯度
    img_pre[:, pre_class].backward()
    # 获取模型的梯度
    gradients = model.get_activations_gradient()
    # 计算梯度相应通道的均值
    mean_gradients = torch.mean(gradients, dim=[0, 2, 3])
    # 获取图像在相应卷积层输出的卷积特征
    activations = model.get_activations(input_im).detach()
    # 每个通道乘以相应的梯度均值
    for i in range(len(mean_gradients)):
        activations[:, i, :, :] *= mean_gradients[i]
    # 计算所有通道的均值输出得到热力图
    heatmap = torch.mean(activations, dim=1).squeeze()
    # 使用Relu函数作用于热力图
    heatmap = F.relu(heatmap)
    # 对热力图进行标准化
    heatmap /= torch.max(heatmap)
    heatmap = heatmap.numpy()

    return heatmap
cam = GradCAM()
# 获取热力图
heatmap = get_heatmap(cam, input_im)
# 可视化热力图
plt.matshow(heatmap)
plt.show()

3.显示结果

# 合并热力图和原图,并显示结果
def merge_heatmap_image(heatmap, image_path):
    img = cv2.imread(image_path)
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    grad_cam_img = heatmap * 0.7 + img
    grad_cam_img = grad_cam_img / grad_cam_img.max()
    # 可视化图像
    b,g,r = cv2.split(grad_cam_img)
    grad_cam_img = cv2.merge([r,g,b])

    plt.figure(figsize=(8,8))
    plt.imshow(grad_cam_img)
    plt.axis('off')
    plt.savefig("./CAM/CBAM_fig2")
    plt.show()
merge_heatmap_image(heatmap, img_path)

通过修改img_path变量,重新运行代码,及可以得到其他图片的CAM结果。这里仅展示部分结果:
可视化学习笔记11-pytorch-GradCAM可视化自己的网络_第1张图片

你可能感兴趣的:(可视化学习,pytorch,学习,网络)