本文主要介绍如何使用pytorch获得已经训练好的网络的中间特征层,并将其转化为热力图的简单方法
效果图
1、在原本的test代码上进行修改
import matplotlib.pyplot as plt
2、随便写一个钩子函数(具体了解可以搜索“pytorch中的钩子(Hook)有何作用?”)
# 用于保存信息
output_list = []
input_list = []
# 定义hook方法(类似一个插件函数)
def forward_hook(module, data_input, data_output):
# 这里简单进行保存相关的特征层
# 也可以对特征层进行操作
input_list.append(data_input)
output_list.append(data_output)
3、然后注册一下钩子函数(在你需要保存的卷积层进行注册)
register_forward_hook为在前向传播时工作
如何查阅相关卷积层的名字,使用model.named_parameters()进行遍历查阅(本文不作详细解释)
# model.结构.某个卷积层.register_forward_hook(forward_hook)
model.det_head.conv2.register_forward_hook(forward_hook)
4、可视化一下
# 特征输出可视化
for i in range(6): # 可视化卷积相应的通道数量
# 以下绘制了一个宽度为6,高度为1的展示区域
plt.subplot(6, 1, i + 1)
plt.axis('off')
# 制定使用jet热力图展示,还有其他的展示形式
plt.imshow(output_list[0].data.cpu().numpy()[0, i, :, :], cmap='jet')
# 保存起来,无白边保存
plt.savefig(file_path, bbox_inches='tight', pad_inches=0) ## 保存图片
# 在批量操作时,每次都会弹出来
# plt.show()
5、代码展示(不全,仅仅展示如何在test.py中添加相关代码)
# 1、导入包
import matplotlib.pyplot as plt
# 2、定义保存信息的数组
output_list = []
input_list = []
# 3、定义hook方法
def forward_hook(module, data_input, data_output):
input_list.append(data_input)
output_list.append(data_output)
# 主要代码如下!!!!
def test(test_loader, model, cfg):
model.eval()
# 4、进行注册hook方法
model.det_head.conv2.register_forward_hook(forward_hook)
for idx, data in enumerate(test_loader):
# 5、遍历操作、每次清空一下
output_list.clear()
input_list.clear()
print('Testing %d/%d\r' % (idx, len(test_loader)), flush=True, end='')
data.update(dict(cfg=cfg))
# 6、forward,hook函数会生效
with torch.no_grad():
outputs = model(**data)
# save result
image_name, _ = osp.splitext(
osp.basename(test_loader.dataset.img_paths[idx]))
rf.write_result(image_name, outputs)
# 7、生成热力图保存路径(自己按照自己的保存路径即可)
tmp_folder = cfg.test_cfg.result_path.replace('.zip', '_visualization')
file_name = '%s.jpg' % image_name
file_path = osp.join(tmp_folder, file_name)
# 8、特征输出可视化
for i in range(6): # 可视化了32通道
plt.subplot(1, 6, i + 1)
plt.axis('off')
plt.imshow(output_list[0].data.cpu().numpy()[0, i, :, :], cmap='jet')
# 9、保存中间特征层的热力图
plt.savefig(file_path, bbox_inches='tight', pad_inches=0) ## 保存图片
# plt.show() # 展示热力图,由于现在是批量操作,故注释
# 正常的模型加载操作,可忽略!!!按照自己的模型加载方法即可!!!
def main(args):
# 读取配置文件
cfg = Config.fromfile(args.config)
# data loader数据加载
data_loader = build_data_loader(cfg.data.test)
test_loader = torch.utils.data.DataLoader(
data_loader,
batch_size=1,
shuffle=False,
num_workers=0,
)
# 模型加载
model = build_model(cfg.model)
model = model.cuda()
# 加载预训练权重
checkpoint = torch.load(args.checkpoint, map_location='cpu')
d = dict()
for key, value in checkpoint['state_dict'].items():
tmp = key[7:]
d[tmp] = value
model.load_state_dict(d)
# test
test(test_loader, model, cfg)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Hyperparams')
parser.add_argument('config', help='config file path')
parser.add_argument('checkpoint', nargs='?', type=str, default=None)
parser.add_argument('--report_speed', action='store_true')
args = parser.parse_args()
main(args)