pytorch对中间特征层可视化方案

本文主要介绍如何使用pytorch获得已经训练好的网络的中间特征层,并将其转化为热力图的简单方法

效果图

1、在原本的test代码上进行修改

import matplotlib.pyplot as plt

2、随便写一个钩子函数(具体了解可以搜索“pytorch中的钩子(Hook)有何作用?”)

# 用于保存信息
output_list = []
input_list = []


# 定义hook方法(类似一个插件函数)
def forward_hook(module, data_input, data_output):
    # 这里简单进行保存相关的特征层
    # 也可以对特征层进行操作
    input_list.append(data_input)
    output_list.append(data_output)

3、然后注册一下钩子函数(在你需要保存的卷积层进行注册)

        register_forward_hook为在前向传播时工作

        如何查阅相关卷积层的名字,使用model.named_parameters()进行遍历查阅(本文不作详细解释)

# model.结构.某个卷积层.register_forward_hook(forward_hook)
model.det_head.conv2.register_forward_hook(forward_hook)

4、可视化一下

# 特征输出可视化
for i in range(6):  # 可视化卷积相应的通道数量
    # 以下绘制了一个宽度为6,高度为1的展示区域
    plt.subplot(6, 1, i + 1)
    plt.axis('off')
    # 制定使用jet热力图展示,还有其他的展示形式
    plt.imshow(output_list[0].data.cpu().numpy()[0, i, :, :], cmap='jet')
# 保存起来,无白边保存
plt.savefig(file_path, bbox_inches='tight', pad_inches=0)  ## 保存图片
# 在批量操作时,每次都会弹出来
# plt.show()  

5、代码展示(不全,仅仅展示如何在test.py中添加相关代码)

# 1、导入包
import matplotlib.pyplot as plt


# 2、定义保存信息的数组
output_list = []
input_list = []


# 3、定义hook方法
def forward_hook(module, data_input, data_output):
    input_list.append(data_input)
    output_list.append(data_output)


# 主要代码如下!!!!
def test(test_loader, model, cfg):
    model.eval()
    
    # 4、进行注册hook方法
    model.det_head.conv2.register_forward_hook(forward_hook)

    for idx, data in enumerate(test_loader):
        # 5、遍历操作、每次清空一下
        output_list.clear()
        input_list.clear()
        print('Testing %d/%d\r' % (idx, len(test_loader)), flush=True, end='')
        data.update(dict(cfg=cfg))


        # 6、forward,hook函数会生效
        with torch.no_grad():
            outputs = model(**data)

        # save result
        image_name, _ = osp.splitext(
            osp.basename(test_loader.dataset.img_paths[idx]))
        rf.write_result(image_name, outputs)

        # 7、生成热力图保存路径(自己按照自己的保存路径即可)
        tmp_folder = cfg.test_cfg.result_path.replace('.zip', '_visualization')
        file_name = '%s.jpg' % image_name
        file_path = osp.join(tmp_folder, file_name)
        # 8、特征输出可视化
        for i in range(6):  # 可视化了32通道
            plt.subplot(1, 6, i + 1)
            plt.axis('off')
            plt.imshow(output_list[0].data.cpu().numpy()[0, i, :, :], cmap='jet')
        # 9、保存中间特征层的热力图
        plt.savefig(file_path, bbox_inches='tight', pad_inches=0)  ## 保存图片
        # plt.show()  # 展示热力图,由于现在是批量操作,故注释


# 正常的模型加载操作,可忽略!!!按照自己的模型加载方法即可!!!
def main(args):
    # 读取配置文件
    cfg = Config.fromfile(args.config)

    # data loader数据加载
    data_loader = build_data_loader(cfg.data.test)
    test_loader = torch.utils.data.DataLoader(
        data_loader,
        batch_size=1,
        shuffle=False,
        num_workers=0,
    )
    
    # 模型加载
    model = build_model(cfg.model)
    model = model.cuda()

    # 加载预训练权重
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    d = dict()
    for key, value in checkpoint['state_dict'].items():
        tmp = key[7:]
        d[tmp] = value
    model.load_state_dict(d)

    # test
    test(test_loader, model, cfg)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('config', help='config file path')
    parser.add_argument('checkpoint', nargs='?', type=str, default=None)
    parser.add_argument('--report_speed', action='store_true')
    args = parser.parse_args()

    main(args)

你可能感兴趣的:(文本检测,pytorch,计算机视觉,目标检测)