pytorch学习笔记十五:Hook函数与CAM可视化

一、Hook函数概念

Hook函数机制:不改变模型主体,实现额外功能,像一个挂件或挂钩等。

为什么需要这个函数呢?这与Pytorch的动态图计算机制有关,在动态图的计算过程中,一些中间变量会释放掉,比如特征图、非叶子节点的梯度,在模型前向传播、反向传播的时候添加hook这个额外函数,提取一些释放掉而后面又需要用到的变量,也可以用hook函数来改变中间变量的梯度。

Pytorch中提供四种hook函数:
1、torch.Tensor.register_hook(hook): 针对tensor
2、torch.nn.Module.register_forward_hook:后面这三个针对Module
3、torch.nn.Module.register_forward_pre_hook
4、torch.nn.Module.register_backward_hook

二、Hook函数与特征提取

1、torch.Tensor.register_hook()

功能:这是一个针对张量的hook函数,作用是注册一个反向传播的hook函数,为什么是在反向传播呢?因为只有在反向传播过程中非叶子的梯度会释放掉,用hook函数来保存这些中间变量的信息。

hook(grad) -> Tensor or None

hook函数仅有一个输入参数为张量的梯度,返回值是tensor或者none
例如:
下图是pytorch中一个简单的计算图与梯度求导
pytorch学习笔记十五:Hook函数与CAM可视化_第1张图片
在上面计算图反响传播过程中,非叶子节点a和b的梯度会释放掉,在前面的学习中可知retain_grad()可保留参数的梯度,也可用hook函数来保留梯度,如下所示:

# 构建计算图,在反向传播中用hook来保存a的梯度
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

# 构建一个list用来存储a的梯度
a_grad = list()

# 自定义hook函数,存放a的梯度,然后将a的梯度存放到前面构建的list中
def grad_hook(grad):
    a_grad.append(grad)

# 接受一个hook函数的钩子,相当于把hook函数挂到计算图上,这样在反向传播时可以保存a的梯度
handle = a.register_hook(grad_hook)

y.backward()

# 查看梯度
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
print("a_grad[0]: ", a_grad[0])
handle.remove()

输出结果:
在这里插入图片描述
可看出在反向传播结束后是将a和b的梯度释放掉了,而hook函数则是保留了a的梯度,这样可以方便后续的使用。另外hook函数可以在反向传播中改变节点的梯度值,如下:

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

a_grad = list()

# 改变节点的梯度值,在hook里可以实现具体的改变方式,并用return返回
def grad_hook(grad):
    grad *= 2
    return grad*3

handle = w.register_hook(grad_hook)

y.backward()

# 查看梯度
print("w.grad: ", w.grad)
handle.remove()

输出结果:
在这里插入图片描述
通过hook函数的变化之后w的梯度变为原来的6倍。

2、Module.register_forward_hook

hook(module, input, output) -> None

功能:注册module前向传播的hook函数
model:当前的网络层
input:当前网络层输入的数据
output:当前网络层输出数据

3、Module.register_forward_pre_hook

hook(module, input) -> None

功能:注册module前向传播的hook函数
module:当前的网络层
input:当前网络层的输入数据
因为这个hook函数是用在前向传播前的函数,所以这里接受参数之后就没有返回值,这个功能可以查看网络之前的数据。

4、Module.register_backward_hook

hook(module, grad_input, grad_output) -> Tensor or None

功能:注册module反向传播的hook函数
module:当前网络层
grad_input:当前网络层的输入梯度数据
grad_output:当前网络层的输出梯度数据

以上就是Pytorch中的hook函数,第一个是针对tensor,后三个是针对module,根据hook函数的使用位置可分为前向传播前,前向传播,反向传播。下面通过具体的示例来了解一下:

假设输入是44的图像经过33的卷积之后得到2*2的feature map,然后经过池化得到后面的输出值,下面就用hook函数来获取中间的feature map层
pytorch学习笔记十五:Hook函数与CAM可视化_第2张图片

#根据上图的示例,构建一个网络,只有卷积和池化两个操作
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.pool1 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        return x
        
#定义前向传播的hook函数
def forward_hook(module, data_input, data_output):
   fmap_block.append(data_output)
   input_block.append(data_input)

#定义前向传播前的hook函数
def forward_pre_hook(module, data_input):
   print("forward_pre_hook input:{}".format(data_input))

#定义反向传播的hook函数
def backward_hook(module, grad_input, grad_output):
   print("backward hook input:{}".format(grad_input))
   print("backward hook output:{}".format(grad_output))

# 初始化网络
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()

# 注册hook
fmap_block = list()
input_block = list()
net.conv1.register_forward_hook(forward_hook)
net.conv1.register_forward_pre_hook(forward_pre_hook)
net.conv1.register_backward_hook(backward_hook)

# inference
fake_img = torch.ones((1, 1, 4, 4))   # batch size * channel * H * W
output = net(fake_img)

loss_fnc = nn.L1Loss()
target = torch.randn_like(output)
loss = loss_fnc(target, output)
loss.backward()

# 观察
print("output shape: {}\noutput value: {}\n".format(output.shape, output))
print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))

输出结果:
pytorch学习笔记十五:Hook函数与CAM可视化_第3张图片
在output = net(fake_img)处打上断点,查看一下上面的三个hook函数是如何实现的。debug进入到module.py中的_call_impl函数中,在这里会调用前向传播函数:
pytorch学习笔记十五:Hook函数与CAM可视化_第4张图片
进一步debug,会进入到构建的网络前向传播函数中
pytorch学习笔记十五:Hook函数与CAM可视化_第5张图片
这是网络的第一个字模块,也就是卷积模块,这里也是定义钩子的地方,进一步debug,可看到又进入到了module.py文件中的_call_impl函数中,仔细观察_call_impl函数可看到主要有四个模块

在上面的示例中设置了三个钩子,分别是前向传播之前,前向传播,反向传播,在不同的过程中会调用_call_impl函数的对应的模块,比如forward_pre_hook钩子会对应上面的第一个模块,然后在result=hook(self, input)会跳到自定义的钩子函数中。继续debug可看到其他钩子也是如此。

【总结】
上面的hook函数的运行机制,都是在module中的_call_impl函数中实现,这个函数完成了4部分的工作,前向传播之前的hook函数(这里钩子主要是查看输入数据的信息),前向传播,forward hook函数(这里的钩子接受参数的输入和输出,存储中间特征图的信息),backward hook函数(这里的钩子常是查看参数的梯度信息)。总体来说hook机制就是在计算图上挂一些钩子,然后钩子上定义一些函数,在不改变模型或者计算图主体的情况下,提供了一些实现别的额外功能的接口。

三、CAM可视化

CAM:类激活图, class activation map。主要功能就是分析卷积神经网络,图像通过卷积神经网络得到了输出之后,可以分析网络是关注图像的哪些部分而得到的这个结果。通过这个可以分析出网络是否学习到了图片中物体本身的特征信息, 如下所示的过程图:
pytorch学习笔记十五:Hook函数与CAM可视化_第6张图片
论文:《Learning Deep Features for Discriminative Localization》

上面网络最后的输出是澳大利亚犬种。那么网络从图像中看到了什么东西才确定是这一个类呢?这里通过CAM算法进行一个可视化,结果就如图中所示。红色的就是网络重点关注的, 在这个结果中看以发现,这个网络重点关注了狗的头部,最后判定是一个这样的犬种。

CAM的基本思想:它会对网络的最后一个特征图进行加权求和,就可以得到一个注意力机制,就是卷积神经网络更关注于什么地方。那如何得到这些特征图的权值呢?对每一个feature map进行golbal average pooling就得到其对应的权值,再通过加权求和最后得到 class activation map。

缺点:CAM是通过golbal average pooling得到权值的,如果输入值改变就得重新训练网络得到权重值,所以就有了如下的改进算法

Grad-CAM:CAM改进版,利用梯度作为特征图权重
pytorch学习笔记十五:Hook函数与CAM可视化_第7张图片
具体思想:根据最后网络输出的向量值进行backward,求出feature map中每一个像素值对应的梯度值,将feature map每一个像素值对应的梯度值进行平均,将梯度的平均值作为此feature map的权重值,然后进行加权求和得到 CAM.
论文:Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization

你可能感兴趣的:(pytorch,深度学习,pytorch)