pytorch中有两种钩子:
一个模型如VGG16是由很多的模块(module)组成的。但是我们在用别人写好了的VGG16的时候,你只能获取到最后的分类结果。当我们想获得其中一些模块如第一个卷积层的输出feature该怎么办呢?没错就是用模块钩子把这一模块的输出特征图钩出来!!!
分为两步:
参考链接1
参考链接2
hook for tensor。
注意:
hook函数中不可以修改其输入。但是可以选择性的返回一个新的梯度来代替当前的梯度。(即不可以修改钩子函数的参数)(钩子函数是具有修改网络梯度的能力的。)
通常的使用方法是定义一个全局变量,在钩子函数中接收需要获取的梯度。或者直接将梯度打印出来。
import torch
from torch.autograd import Variable
grad_list = []
def print_grad(grad):
grad_list.append(grad)
x = Variable(torch.randn(2, 1), requires_grad=True)
y = x+2
z = torch.mean(torch.pow(y, 2))
lr = 1e-3
y.register_hook(print_grad)
z.backward()
x.data -= lr*x.grad.data
# 定义钩子
# 定义字典存储不同层的输出
activation={}
# 定义钩子函数
def get_activation(name):
def hook(model,input,output): # 此处的参数input以及output就是调用钩子函数的module的输入和输出。
activation[name]=output.detach()
return hook
# 目的:可视化features[4]的特征图,也就是第一个maxpool之后的输出特征
# 使用钩子抓取:
vgg16.eval()
vgg16.features[4].register_forward_hook(get_activation('maxpool1')) # 注册钩子
vgg16(input_im) # 网络计算
maxpool1=activation['maxpool1']
print(maxpool1.shape)
//输出:
torch.Size([1, 64, 112, 112])
# 定义钩子
# 定义字典存储不同层的输出
activation={}
# 定义钩子函数
def get_activation(name):
def hook(model,input,output):
activation[name]=output.detach() #detach()表示将该运算从计算图上分离出来
return hook
# 目的:可视化features[4]的特征图,也就是第一个maxpool之后的输出特征
# 使用钩子抓取:
vgg16.eval()
vgg16.features[4].register_forward_hook(get_activation('maxpool1')) # 注册钩子
vgg16(input_im) # 网络计算
maxpool1=activation['maxpool1']
# 将特征图可视化
plt.figure(figsize=(11,6))
for i in range(maxpool1.size(1)):
plt.subplot(6,11,i+1)
plt.imshow(maxpool1.data.numpy()[0,i,:,:],cmap='gray')
plt.axis('off')
plt.subplots_adjust(wspace=0.1,hspace=0.1)
plt.show()
Grad-CAM详解参考链接
Grad-CAM就是利用分类网络最后的分类得分求解最后一个卷积层的每个特征图的梯度,将该梯度看做每个特征图的得分,然后将得分与特征图进行加权求和得到最后的激活热力图。
# 要实现Grad-CAM,需要最后一个卷积层的输出以及其梯度。
# 定义钩子获取最后一个卷积层的输出(特征图)和梯度
class MyVgg16(nn.Module):
def __init__(self):
super(MyVgg16,self).__init__()
self.vgg=models.vgg16(pretrained=True)
# 因为要获取最后一个卷积层的输出,因此将该层与后面的池化层分开,以便使用hook
self.features_conv=self.vgg.features[:30]
self.maxpool=self.vgg.features[30]
self.avgpool=self.vgg.avgpool
self.classifier=self.vgg.classifier
# 定义gradient的占位符
self.gradient=None
# 获取梯度的钩子函数
def activation_hook(self,grad):
self.gradient=grad
def forward(self,x):
x=self.features_conv(x)
# 注册一个backward钩子,每次梯度被计算到时都会调用该钩子
h=x.register_hook(self.activation_hook)
x=self.maxpool(x)
x=self.avgpool(x)
x=x.view((1,-1))
x=self.classifier(x)
return x
# 获取梯度得方法
def get_activation_grad(self):
return self.gradient
# 获取卷积层输出的方法
def get_activation_feat(self,x):
return self.features_conv(x)
# 将图像输入到网络中进行分类
net=MyVgg16()
output=net(input_im)
pre_prob=nn.Softmax(dim=1)(output)
# 获取网络输出概率排名前五的输出
pre,pre_index=torch.topk(pre_prob,5) # torch.topk(tensor,k):返回值是两个tensor分别是:topk的tensor值以及对应的index
pre=pre.data.squeeze(0).numpy()
pre_index=pre_index.data.squeeze(0).numpy()
# 第一步:获取预测的可能性最好的类别的对应特征图的梯度
output[:,pre_index[0]].backward()
gradient=net.get_activation_grad()
# print(gradient.shape) # 输出:torch.Size([1, 512, 14, 14])
# 第二步:获取最后一个卷积层的特征图
features=net.get_activation_feat(input_im).detach()
# print(features.shape) # 输出:torch.Size([1, 512, 14, 14])
# 将梯度与对应的特征图进行加权求和得到热力图
# 首先,计算梯度的每个通道的响应均值
gradient=gradient.squeeze(0)
print(gradient.shape) # 输出:torch.Size([512])
gradient_mean=torch.mean(gradient.view(gradient.shape[0],-1),dim=1)
print(gradient_mean.shape) # 输出:torch.Size([512])
# 然后,将每个通道的均值与对应的特征图进行加权
for i in range(gradient_mean.shape[0]):
features[:,i,:,:]*=gradient_mean[i].data
# 最后,将每个通道的加权后的特征图求平均得到最终的热力图
heatmap=torch.mean(features,dim=1).squeeze(0)
# print(feature_cam.shape) # 输出:torch.Size([14, 14])
heatmap=F.relu(heatmap)
heatmap/=torch.max(heatmap)
heatmap=heatmap.numpy()
# 可视化
plt.matshow(heatmap)
# 将Crad-CAM热力图融合到原始图像上
# 读取原始图像
img=cv2.imread('./data/chap6/大象.jpg')
# print(type(img)) # 输出:numpy
# print(img.shape) # 输出:(365, 550, 3) cv2读取图像的维度为h*w*c
heatmap=cv2.resize(heatmap,(img.shape[1],img.shape[0]))
heatmap=np.uint8(255*heatmap)
heatmap=cv2.applyColorMap(heatmap,cv2.COLORMAP_JET)
grad_cam_img=heatmap*0.4+img
grad_cam_img=grad_cam_img/grad_cam_img.max()
# 可视化图像
b,g,r=cv2.split(grad_cam_img)
grad_cam_img=cv2.merge([r,g,b])
plt.figure()
plt.imshow(grad_cam_img)
plt.show()