pytorch hook机制

pytorch_hook机制

文章目录

    • pytorch_hook机制
      • 前言
      • 一、hook简介
      • 二、Tensor的hook机制
      • 三、基于Module的hook机制
        • register_forward_hook
        • regitster_backward_hook
        • 小结
        • 示例代码
      • 参考文献

前言

在理解hook机制之前,首先应该对pytorch张量的自动求导机制有所了解:PyTorch的自动求导机制详细解析,进而理解正向传播和反向传播的过程中程序在做什么。

在pytorch前向传播的过程中,会动态生成计算图;在反向传播过程中,对计算图中的每个模块的输入输出求解梯度,并把梯队回传到输出。在反向传播过程中为了减少内存消耗,会把过程中产生的梯度删除,仅保留计算图中叶子节点的梯度信息。但是,一些应用要求使用神经网络中间层输入输出的梯度值,如神经网络可解释性的CAM算法等。这时,使用hook机制就能帮助实现这个目标。

一、hook简介

hook机制主要非为两类:基于Tensor的hook机制,以及基于Module的hook机制。我的理解是,基于Tensor方便追踪某个特定张量的梯度,如某层的某个特定的输入;基于Module的hook机制则是在实际中应用比较多,可以帮助获得某层的输入输出的梯度,用于后续的计算。

无论使用哪种类型的hook机制,pytorch都要求我们注册一个hook,我的理解是,使用hook机制相当于一个钩子钩住了网络的前向传播或者反向传播,让用户可以在这中间添加一些操作(也就是调用一个能对前向、后向传播中间信息进行操作的函数)

二、Tensor的hook机制

这里使用一个简单的例子,参考代码见参考链接1。

使用**register_hook(hook)**注册一个钩子,也就是注册一个添加到计算图中间的函数。钩子函数使用的格式为:

hook(grad) -> Tensor or None

该例子使用一个简单的计算图进行计算。

x x x是随机生成一个3*1的tensor;

y = x + 3 y = x + 3 y=x+3

z = m e a n ( s u m ( y ) ) z=mean(sum(\sqrt{y})) z=mean(sum(y ))

import torch
def print_grad(grad):
    print('grad is \n',grad)

x = torch.rand(3,1,requires_grad=True)
print('x value is \n',x)
y = x+3
print('y value is \n',y)
z = torch.mean(torch.pow(y, 1/2))
lr = 1e-3

y.register_hook(print_grad) 
z.backward() # 梯度求解
x.data -= lr*x.grad.data
print('new x is\n',x)

输出:

x value is 
 tensor([[0.5681],
        [0.4868],
        [0.9277]], requires_grad=True)
y value is 
 tensor([[3.5681],
        [3.4868],
        [3.9277]], grad_fn=<AddBackward0>)
grad is 
 tensor([[0.0882],
        [0.0893],
        [0.0841]])
new x is
 tensor([[0.5680],
        [0.4867],
        [0.9276]], requires_grad=True)

在tensor x , y , z x,y,z x,y,z之间的函数关系定义好之后,计算图成功生成,中间变量的值被成功计算出来;在调用backward函数之前首先需要注册hook,否则hook就不会在backward的过程中被执行;最后,在反向传播过程中计算出来 y y y的梯度值并输出。

下面简单说明输出的结果:

输入为: x 1 , x 2 , x 3 x_1,x_2,x_3 x1,x2,x3

第一次计算: y i = x i + 3 , i = 1 , 2 , 3 y_i=x_i+3,i=1,2,3 yi=xi+3,i=1,2,3,那么梯度为:KaTeX parse error: Undefined control sequence: \part at position 7: \frac{\̲p̲a̲r̲t̲{y_i}}{\part{x_…

第二次计算: z = m e a n ( ∑ y i ) z=mean(\sum{\sqrt {y_i}}) z=mean(yi ),那么梯度为:KaTeX parse error: Undefined control sequence: \part at position 7: \frac{\̲p̲a̲r̲t̲{z}}{\part{y_i}…

通过梯度传递原则,可以得到KaTeX parse error: Undefined control sequence: \part at position 7: \frac{\̲p̲a̲r̲t̲{z}}{\part{x_i}…

借助上面的推导,可以很快自行验证上述结果的正确性。

三、基于Module的hook机制

与上面的基于tensor的hook机制原理相近,基于Module的hook机制是对模型的模块进行操作,比如神经网络的某个隐藏层。

有两种方法,用于前向传递的hook和用于后向传递的hook;

register_forward_hook

前向传递的hook主要用于在前向传播的过程中钩取模块之间的输入输出,使用**register_forward_hook(hook)**注册前向钩子,其中hook函数是一个如下形式的函数:

hook(module, input, output) -> None or modified output

regitster_backward_hook

后向传递的hook主要用于在后向传播的过程中钩取模块输入输出的梯度信息,使用**register_backward_hook(hook)**注册后向钩子,其中hook函数是一个如下形式的函数:

hook(module, grad_input, grad_output) -> Tensor or None

如果使用register_backward_hook函数目前会报

"warning:

Using a non-full backward hook when the forward contains multiple autograd Nodes"

在某些具有多个自动求导节点的场合,需要使用register_full_backward_hook获得完整的梯度信息(至于两者的区别,因为具体原理笔者还没有很理解,因此不多赘述。但是观察到如果对整个模型注册钩子,两个函数的输出结果不同)

小结

综上,我们只要能够理解register_forward_hook函数、register_backward_hook函数的不同用处即可,前向用于在前向传播过程中钩取模块输入输出值,后向用于在后向传播过程中钩取模块输入输出的梯度结果。

另外,事实上基于tensor的hook机制是类似于后向的hook机制的,hook函数是在后向传递的过程中被调用的。

示例代码

这里参考参考链接2中的代码,为了能够明显区分出前向和后向hook的区别进行了一定的修改,。

import torch
import torch.nn as nn
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class MyMean(nn.Module):            # 自定义除法module
    def forward(self, input):
        out = input/4
        return out

def tensor_hook(grad):
    print('tensor hook')
    print('grad:', grad)
    return grad

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.f1 = nn.Linear(4, 1, bias=True) 
        self.f2 = MyMean()
        self.weight_init()

    def forward(self, input):
        self.input = input
        output = self.f1(input) # 先进行运算1,后进行运算2
        output = self.f2(output) 
        return output

    def weight_init(self):
        self.f1.weight.data.fill_(8.0) # 这里设置Linear的权重为8
        self.f1.bias.data.fill_(2.0) # 这里设置Linear的bias为2

    def my_backward_hook(self, module, grad_input, grad_output):
        print('doing my_backward_hook')
        print('original grad:', grad_input)
        print('original outgrad:', grad_output)
        return grad_input
    
    def my_forward_hook(self, module, input, output):
        print('doing my_forward_hook')
        print('input:', input)
        print('output', output)
     

if __name__ == '__main__':
    input = torch.tensor([1, 2, 3, 4], dtype=torch.float32, requires_grad=True).to(device)

    
    net = MyNet()
    net.to(device)
    net.f1.register_forward_hook(net.my_forward_hook) # 这两个hook函数一定要result = net(input)执行前执行,因为hook函数实在forward的时候进行绑定的
    net.f2.register_forward_hook(net.my_forward_hook)

    net.f1.register_full_backward_hook(net.my_backward_hook) # 这两个hook函数一定要result = net(input)执行前执行,因为hook函数实在forward的时候进行绑定的
    net.f2.register_full_backward_hook(net.my_backward_hook)
    input.register_hook(tensor_hook)
    
    print('forward now')
    result = net(input)
    print('result =', result)
    print('over forward')

    print('\nbackward now')
    result.backward()
    print('over backward')
    
    print('\ninput.grad:', input.grad)
    for param in net.parameters():
        print('{}:grad->{}'.format(param, param.grad))

该网络定义了简单的单层全连接网络,为了方便理解,采用了固定的参数,因此可以很方便地证明代码的运行结果,这里不再赘述。

得到的输出是:

forward now
doing my_forward_hook
input: (tensor([1., 2., 3., 4.], grad_fn=<BackwardHookFunctionBackward>),)
output tensor([82.], grad_fn=<AddBackward0>)
doing my_forward_hook
input: (tensor([82.], grad_fn=<BackwardHookFunctionBackward>),)
output tensor([20.5000], grad_fn=<DivBackward0>)
result = tensor([20.5000], grad_fn=<BackwardHookFunctionBackward>)
over forward

backward now
doing my_backward_hook
original grad: (tensor([0.2500]),)
original outgrad: (tensor([1.]),)
doing my_backward_hook
original grad: (tensor([2., 2., 2., 2.]),)
original outgrad: (tensor([0.2500]),)
tensor hook
grad: tensor([2., 2., 2., 2.])
over backward

input.grad: tensor([2., 2., 2., 2.])
Parameter containing:
tensor([[8., 8., 8., 8.]], requires_grad=True):grad->tensor([[0.2500, 0.5000, 0.7500, 1.0000]])
Parameter containing:
tensor([2.], requires_grad=True):grad->tensor([0.2500])

对于结果值得注意的我认为有两点:

  1. 注意hook函数调用的位置,两者分别在forward和backward的过程中调用;
  2. 注册钩子需要在模型前向传播之前,因为hook是在前向传播的过程中链接上hook的。

同样需要注意,在使用过后,得到的hook可以通过hook.remove()的方法将其移除,减少内存消耗。

参考文献

参考链接1:register_farward_hook

参考链接2:register_backward_hook

你可能感兴趣的:(pytorch学习记录,pytorch,python)