Grad-CAM可视化

图片读取:

from PIL import Image

img = Image.open(img_path).convert('RGB')

第一种读取方式: 

import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from torch.autograd import Variable
import numpy as np
import cv2
import torch.autograd as autograd

def draw_CAM(model, data, rgb_bands=(18, 28, 38), visual_heatmap=False):
    '''
    绘制 Class Activation Map
    :param model: 加载好权重的Pytorch model
    :param data: 数据。维度为(h,w,c)
    :param rgb_bands: 选择可视化的通道
    :param visual_heatmap: 是否可视化原始heatmap(调用matplotlib)
    :return:
    '''

    model.eval()
    preprocess = transforms.Compose([
        transforms.ToTensor(),
    ])

    # 图像加载&预处理
    img = preprocess(data)
    img = Variable(img.unsqueeze(0))

    # 获取模型输出的feature/score
    output, features = model(img) # ★ 此处在模型中

你可能感兴趣的:(记录本,大数据)