pytorch注册hook获取网络任意中间层的特征和梯度

当需要获取网络中间层输出时,如果模型是自定义的,可以在模型定义时直接将想获取的那一层作为输出返回,如分类网络同时返回最终输出类别和最后一层特征:
在这里插入图片描述

但当我们调用封装好的模型,又不想重写模型,但还要获取网络中间输出时,就可使用hook机制

首先打印模型,查看待输出层的名称

print(model)

定义hook

pytorch注册hook获取网络任意中间层的特征和梯度_第1张图片
在这里插入图片描述
在这里插入图片描述
取特征:

feature = fmap_block[‘input’]

取梯度:

grad = grad_block[‘grad_in’]

参考:

https://www.jianshu.com/p/69e57e3526b3

你可能感兴趣的:(日常琐碎积累,pytorch,人工智能,python)