Grad-CAM:main_cnn.py源码解析

        论文地址:https://arxiv.org/abs/1610.02391

        另一篇代码解析:Grad-CAM:utils.py源码解析

        代码用的是根据官方代码精炼后的一个代码,自己做了注解:

import os
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import models
from torchvision import transforms
from utils import GradCAM, show_cam_on_image, center_crop_img


def main():
    # model = models.mobilenet_v3_large(pretrained=True)
    # target_layers = [model.features[-1]]

    model = models.vgg16(pretrained=True)
    target_layers = [model.features]

    # model = models.resnet34(pretrained=True)
    # target_layers = [model.layer4]

    # model = models.regnet_y_800mf(pretrained=True)
    # target_layers = [model.trunk_output]

    # model = models.efficientnet_b0(pretrained=True)
    # target_layers = [model.features]

    # 定义图像预处理方式
    data_transform = transforms.Compose([transforms.ToTensor(),
                                         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image
    img_path = "both.png"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path).convert('RGB')  # 读取图片并转换为RGB
    img = np.array(img, dtype=np.uint8)  # 转换成NumPy格式
    # img = center_crop_img(img, 224)

    # [C, H, W]
    img_tensor = data_transform(img)
    # expand batch dimension
    # [C, H, W] -> [N, C, H, W]
    input_tensor = torch.unsqueeze(img_tensor, dim=0)  # 增加batch维度

    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
    target_category = 281  # 指定感兴趣的类别,tabby, tabby cat
    # target_category = 254  # pug, pug-dog

    grayscale_cam = cam(input_tensor=input_tensor, target_category=target_category)  # 传到__call__方法里

    grayscale_cam = grayscale_cam[0, :]  # 把第一张图片的grayscale_cam提取出来;这里只传入了1张图片
    # 绘制最终的热力图
    visualization = show_cam_on_image(img.astype(dtype=np.float32) / 255.,  # 将原图像素缩放到0-1之间
                                      grayscale_cam,
                                      use_rgb=True)
    plt.imshow(visualization)
    plt.show()


if __name__ == '__main__':
    main()

你可能感兴趣的:(Deep,Learning,Tricks,#,可视化,cnn,深度学习,python,pytorch)