使用 PyTorch 中的钩子(hook)将特征图和梯度提取出来,从而达到可视化特征图(feature map)和可视化热图(heatmap)的目的。
提示:以下为正文内容,文中案例代码可供参考。
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
import torchvision.models as models
import torch
import torch.nn.functional as F
from ipdb import set_trace
# ----------------------------------- feature map visualization -----------------------------------
# TensorBoard writer that will receive the per-layer feature-map grids.
writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix")

# ---- Data loading and preprocessing ----
path_img = "./dog.jpg"  # your path to image
# Per-channel normalization statistics (CIFAR-10-style values).
normMean = [0.49139968, 0.48215827, 0.44653124]
normStd = [0.24703233, 0.24348505, 0.26158768]
norm_transform = transforms.Normalize(normMean, normStd)
img_transforms = transforms.Compose([
    transforms.Resize((224, 224)),  # VGG16 expects 224x224 inputs
    transforms.ToTensor(),
    norm_transform,
])

img_pil = Image.open(path_img).convert('RGB')
# NOTE: the original guarded this with `if img_transforms is not None`, but the
# transform is unconditionally defined just above, so the dead check is removed.
img_tensor = img_transforms(img_pil)
img_tensor.unsqueeze_(0)  # chw --> bchw

# ---- Model loading ----
# Build the architecture without downloading weights, then load a local checkpoint.
vggnet = models.vgg16_bn(pretrained=False)
pthfile = './pretrained/vgg16_bn-6c64b313.pth'
vggnet.load_state_dict(torch.load(pthfile))
# ---- Hook state ----
# Maps str(conv.weight.shape) -> list of feature maps captured during forward.
fmap_dict = dict()
n = 0  # number of Conv2d layers found (kept for parity with the original script)

def hook_func(m, i, o):
    """Forward hook: stash the layer output `o`, keyed by the layer's weight shape.

    Uses setdefault so the hook no longer relies on the key having been
    pre-registered by the scanning loop (the original raised KeyError otherwise).
    NOTE: distinct layers whose weights share a shape append to the same list.
    """
    key_name = str(m.weight.shape)
    fmap_dict.setdefault(key_name, []).append(o)
# Scan the network and attach the forward hook to every Conv2d layer.
for name, sub_module in vggnet.named_modules():
    if isinstance(sub_module, nn.Conv2d):
        n += 1
        # Pre-register the key so downstream code can index fmap_dict directly.
        key_name = str(sub_module.weight.shape)
        fmap_dict.setdefault(key_name, list())
        # `sub_module` IS the module located at `name`, so hook it directly.
        # The original re-resolved it via vggnet._modules[n1]._modules[n2],
        # which breaks for module names that are not exactly two levels deep.
        sub_module.register_forward_hook(hook_func)
# ---- Forward pass (hooks fill fmap_dict) ----
# Visualization only: no_grad avoids retaining an autograd graph for every
# captured feature map (the original kept the whole graph alive for nothing).
with torch.no_grad():
    output = vggnet(img_tensor)
print(fmap_dict['torch.Size([128, 64, 3, 3])'][0].shape)

# ---- Write one image grid per captured layer ----
for layer_name, fmap_list in fmap_dict.items():
    fmap = fmap_list[0]                 # (1, C, H, W): first capture for this key
    # Out-of-place transpose: the original transpose_() mutated the cached tensor.
    fmap = fmap.transpose(0, 1)         # -> (C, 1, H, W): one grid cell per channel
    nrow = int(np.sqrt(fmap.shape[0]))  # roughly square grid
    # Upsample so small late-stage maps stay legible in TensorBoard.
    fmap = F.interpolate(fmap, size=[112, 112], mode="bilinear")
    fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow)
    print(type(fmap_grid), fmap_grid.shape)
    writer.add_image('feature map in {}'.format(layer_name), fmap_grid, global_step=322)
# After the run finishes, launch: tensorboard --logdir=<event-file directory>
import cv2
import os
import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision import models
import json
from ipdb import set_trace
# ---- Image preprocessing ----
def img_preprocess(img_in):
    """Convert a BGR uint8 image (as read by cv2) into a normalized 1xCxHxW tensor."""
    bgr = img_in.copy()
    # BGR -> RGB, then force a contiguous buffer so ToTensor accepts it.
    rgb = np.ascontiguousarray(bgr[:, :, ::-1])
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4948052, 0.48568845, 0.44682974],
                             [0.24580306, 0.24236229, 0.2603115]),
    ])
    tensor = pipeline(rgb)
    return tensor.unsqueeze(0)  # add the batch dimension: chw -> bchw
# ---- Gradient-capturing hook ----
def backward_hook(module, grad_in, grad_out):
    """Backward hook: store the gradient w.r.t. the module's output in grad_block."""
    captured = grad_out[0]
    grad_block.append(captured.detach())
# ---- Feature-map-capturing hook ----
# (Name kept as-is, including the original "farward" spelling, because it is
# referenced by the registration code elsewhere in this script.)
def farward_hook(module, input, output):
    """Forward hook: store the module's output feature map in fmap_block."""
    feature_map = output
    fmap_block.append(feature_map)
# ---- Grad-CAM computation and visualization ----
def cam_show_img(img, feature_map, grads, out_dir):
    """Compute a Grad-CAM heatmap and save it blended over the input image.

    Args:
        img: original BGR image, shape (H, W, 3), as read by cv2.
        feature_map: target-layer activations, shape (C, h, w) — assumed; confirm at call site.
        grads: gradients w.r.t. that layer, same shape as feature_map.
        out_dir: directory where "cam.jpg" is written.
    """
    H, W, _ = img.shape
    cam = np.zeros(feature_map.shape[1:], dtype=np.float32)
    # Global-average-pool the gradients: one importance weight per channel.
    grads = grads.reshape([grads.shape[0], -1])
    weights = np.mean(grads, axis=1)
    for i, w in enumerate(weights):
        cam += w * feature_map[i, :, :]  # weighted sum of channel activations
    cam = np.maximum(cam, 0)  # ReLU: keep only positive influence
    # Guard against an all-zero map — the original divided by zero here.
    max_val = cam.max()
    if max_val > 0:
        cam = cam / max_val
    cam = cv2.resize(cam, (W, H))
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    # Blend heatmap over the original image; clip and cast so imwrite
    # receives uint8 data instead of a float array.
    cam_img = np.uint8(np.clip(0.3 * heatmap + 0.7 * img, 0, 255))
    path_cam_img = os.path.join(out_dir, "cam.jpg")
    cv2.imwrite(path_cam_img, cam_img)
if __name__ == '__main__':
    path_img = './cam/test.png'
    json_path = './cam/labels.json'
    output_dir = './cam'

    # Load the ImageNet class-index -> label mapping.
    with open(json_path, 'r') as load_f:
        load_json = json.load(load_f)
    classes = {int(key): value for (key, value) in load_json.items()}
    # Keep only the label names, ordered by class index (assumes keys 0..999).
    classes = list(classes.get(key) for key in range(1000))

    # Containers the hooks fill during the forward/backward passes.
    fmap_block = list()
    grad_block = list()

    # Read the image and build the network.
    img = cv2.imread(path_img, 1)
    img_input = img_preprocess(img)
    # Load a pretrained ResNet-50 (the original comment wrongly said squeezenet1_1).
    net = models.resnet.resnet50(pretrained=True)
    net.eval()  # inference mode: freezes batch-norm / dropout behavior
    print(net)

    # Register hooks on the last residual block of layer4.
    net.layer4[-1].register_forward_hook(farward_hook)
    # register_backward_hook is deprecated and unreliable on modules containing
    # multiple operations; the full variant is the documented replacement.
    net.layer4[-1].register_full_backward_hook(backward_hook)

    # Forward pass: predict the class.
    output = net(img_input)
    idx = np.argmax(output.cpu().data.numpy())
    print("predict: {}".format(classes[idx]))

    # Backward pass from the predicted class score to populate grad_block.
    net.zero_grad()
    class_loss = output[0, idx]
    class_loss.backward()

    # Build and save the Grad-CAM overlay.
    grads_val = grad_block[0].cpu().data.numpy().squeeze()
    fmap = fmap_block[0].cpu().data.numpy().squeeze()
    cam_show_img(img, fmap, grads_val, output_dir)
以上通过 PyTorch 中的钩子实现了特征图可视化和热图可视化,均针对图像分类任务。其中特征图可视化可以直接迁移到目标检测、语义分割等其他任务,但热图可视化不能直接迁移;未来会增加针对目标检测的热图可视化相关工作。