特征图可视化(pytorch)

本篇博客所说的可视化,是指可视化网络中每一层输出的特征图,而不是类激活图(CAM)可视化;CAM可视化可以参考《Grad-CAM实现流程(pytorch)》.
这篇博客的目的仅是记录而已,由于距离上次使用过于久远,具体参考文章已经找不到,因此结尾未加入参考链接.
可视化效果如下图:

浅层

深层
特征图可视化(pytorch)_第1张图片
代码
利用TensorBoard可视化特征图,以带BN层的VGG16(torchvision中的vgg16_bn)为例.脚本运行结束后,执行 tensorboard --logdir=runs 即可在浏览器中查看特征图.

import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
import torchvision.models as models
import torch
import torch.nn.functional as F

# ----------------------------------- feature map visualization -----------------------------------

# TensorBoard writer: `comment` is appended to the default run directory name
# and `filename_suffix` to the event file name.
writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix")

# Data loading and preprocessing
path_img = "./Forsters_Tern_0016_152463.jpg"     # your path to image
# ImageNet channel statistics. The checkpoint restored further down
# (vgg16_bn-6c64b313.pth) is torchvision's ImageNet-pretrained model, so the
# input must be normalized the way that model was trained; CIFAR-10 statistics
# would feed it a shifted input distribution.
normMean = [0.485, 0.456, 0.406]
normStd = [0.229, 0.224, 0.225]

norm_transform = transforms.Normalize(normMean, normStd)
img_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    norm_transform,
])
# Open inside a context manager so the file handle is released immediately;
# convert() returns a fully loaded image that stays valid after the file closes.
with Image.open(path_img) as img_file:
    img_pil = img_file.convert('RGB')
img_tensor = img_transforms(img_pil).unsqueeze(0)    # chw --> bchw

# Model loading: build the architecture without fetching weights, then restore
# them from a local checkpoint file.
vggnet = models.vgg16_bn(pretrained=False)
pthfile = './pretrained/vgg16_bn-6c64b313.pth'
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; the model and the input tensor both live on the CPU here.
vggnet.load_state_dict(torch.load(pthfile, map_location="cpu"))
# Switch to inference mode: BatchNorm then uses its stored running statistics
# instead of the statistics of this single image, and Dropout is disabled.
vggnet.eval()
# print(vggnet)

# Hook registration state.
# fmap_dict maps str(conv.weight.shape) to the list of outputs captured from
# every hooked module whose weight has that shape, in the order they ran.
fmap_dict = dict()
n = 0   # incremented once per Conv2d layer by the registration loop


def hook_func(m, i, o):
    """Forward hook that records a module's output in ``fmap_dict``.

    Args:
        m: The module the hook fired on; it must have a ``weight`` attribute,
            whose shape is used as the dictionary key.
        i: The module's forward inputs (unused).
        o: The module's forward output tensor.

    Returns:
        None, so the module's output is passed on unchanged.
    """
    key_name = str(m.weight.shape)
    # setdefault: do not fail if the key was not created ahead of time.
    # detach: keep only the activation values, not the autograd graph behind them.
    fmap_dict.setdefault(key_name, []).append(o.detach())

# Attach hook_func to every convolution layer.
# named_modules() yields (qualified name, module) for the network itself and
# for each nested sub-module, at any depth.
for name, sub_module in vggnet.named_modules():
    if not isinstance(sub_module, nn.Conv2d):
        continue
    n += 1
    key_name = str(sub_module.weight.shape)
    # dict.setdefault() inserts the key with the given default only when the
    # key is missing, so layers that share a weight shape share one list.
    fmap_dict.setdefault(key_name, list())
    # Register on the module object itself instead of looking it up again via
    # name.split("."): that lookup only works for names with exactly two
    # components (e.g. "features.0") and fails on deeper nesting.
    sub_module.register_forward_hook(hook_func)

# forward
# Only the activations are needed, so skip building the autograd graph.
with torch.no_grad():
    output = vggnet(img_tensor)
print(fmap_dict['torch.Size([128, 64, 3, 3])'][0].shape)
# add image
for layer_name, fmap_list in fmap_dict.items():
    if not fmap_list:
        # No hooked layer with this weight shape ran during the forward pass.
        continue
    # Layers that share a weight shape share one list; only the first captured
    # output is drawn. Swapping dims 0 and 1 turns the (batch, channel, H, W)
    # output into one single-channel image per feature channel (batch is 1).
    # The non-in-place transpose leaves the tensor stored in fmap_dict untouched.
    fmap = fmap_list[0].transpose(0, 1)

    # Roughly square grid: about sqrt(#channels) images per row.
    nrow = int(np.sqrt(fmap.shape[0]))
    # Bring every layer to the same spatial size so the grids are comparable.
    # align_corners=False is the default behavior, stated explicitly to avoid
    # the warning about the implicit default.
    fmap = F.interpolate(fmap, size=[112, 112], mode="bilinear", align_corners=False)

    # scale_each=True normalizes every channel image to [0, 1] independently.
    fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow)
    print(type(fmap_grid),fmap_grid.shape)
    writer.add_image('feature map in {}'.format(layer_name), fmap_grid, global_step=322)

# Flush buffered events to disk and release the event file.
writer.close()

你可能感兴趣的:(pytorch)