pytorch使用笔记(二):模型钩子(Hook for Modules)的使用

前言

pytorch中有两种钩子:Hook for Tensor和Hook for Modules。在本文中只介绍后者,因为后者更为常用:)。模型钩子分为两种:钩forward信息流的钩子和钩backwar信息流的钩子。

 

为什么要使用钩子?

一个模型如VGG16是由很多的模块(module)组成的。但是我们在用别人写好了的VGG16的时候,你只能获取到最后的分类结果。当我们想获得其中一些模块如第一个卷积层的输出feature该怎么办呢?没错就是用模块钩子把这一模块的输出特征图钩出来!!!

 

如何使用钩子

分为两步:

  1. 定义钩子:其实钩子就是一个函数,只不过函数的形参是已经固定了的,你在函数里面操作这些形参就可以了。
  2. 注册钩子:对需要钩出信息的模块注册一个钩子。

 

钩forward信息流的钩子

        def forward_hook_fn(module, input_, output):
            print("使用普通前向钩子")
            self.activation_maps.append(output)

前向钩子形参有:被钩模块,被钩模块的输入feature map和被钩模块的输出feature map。这三个也是我们可以操作的东西。在张量我们就输出特征图保存到列表activation_mpas中。

 

modules = list(self.model.modules())    # 拿出此模型的所有模块
for module in modules:  # 根据模块的类型进行注册。在这里所有的Relu模块被注册
   if isinstance(module, nn.ReLU):
        module.register_forward_hook(forward_hook_fn)   # 使用register_forward_book进行注册 

 

注册的方式很简单就是调用了一下register_forward_hook.但是要注意的是如何选出你需要的模块。这里是根据类型选择。下一个例子中你会看到根据模块的所在位置来选择模块。

 

钩backward信息流的钩子

        def firstLayer_backward_hook_fn(module, grad_in, grad_out):
            print("使用第一层反向钩子")
            self.image_reconstruction = grad_in[0]

反向钩子形参有:被钩模块,被钩模块的输入梯度与被钩模块的输出梯度。比如该模块是一个全连接层,即y=wx+b。那么grad_in是一个元组其中包含:(bias的梯度,w的梯度,特征图x的梯度)。对该模块是卷积层时grad_in=(特征图x的梯度,w的梯度,bias的梯度)。在这里我们将第一个卷积层的特征图梯度存储到image_reconstruction中去

 

modules = list(self.model.modules())    # 拿出次模型的所有模块
first_layer = modules[1][0] # 拿出第一个卷积层
first_layer.register_backward_hook(firstLayer_backward_hook_fn) # 注册反向钩子

 

为什么是modules[1][0]表示第一个卷积层呢?我们输出一下moduels就知道了!!!

 

pytorch使用笔记(二):模型钩子(Hook for Modules)的使用_第1张图片

 

conclusion

这个钩子在你以后改造网络的时候肯定是个大杀器的哦!!2020-06-19 :)

你可能感兴趣的:(AI,pytorch,深度学习)