PyTorch学习—16.PyTorch中hook函数

文章目录

      • 一、Hook函数概念
      • 二、四种Hook函数介绍
        • 1. Tensor.register_hook
        • 2. Module.register_forward_hook
        • 3.Module.register_forward_pre_hook
        • 4. Module.register_backward_hook

一、Hook函数概念

  Hook函数机制:不改变主体,实现额外功能,像一个挂件一样将功能挂到函数主体上。Hook函数与PyTorch中的动态图运算机制有关,因为在动态图计算,在运算结束后,中间变量是会被释放掉的,例如:非叶子节点的梯度。但是,我们往往想要提取这些中间变量,这时,我们就可以采用Hook函数在前向传播与反向传播主体上挂上一些额外的功能(函数),通过这些函数获取中间的梯度,甚至是改变中间的梯度。PyTorch一共提供了四种Hook函数:

  • torch.Tensor.register_hook(hook)
  • torch.nn.Module.register_forward_hook
  • torch.nn.Module.register_forward_pre_hook
  • torch.nn.Module.register_backward_hook

一种是针对Tensor,其余三种是针对网络的

二、四种Hook函数介绍

1. Tensor.register_hook

def register_hook(self, hook):
	"""
	接受一个hook函数
	"""
	...

功能:注册一个反向传播hook函数,这是因为张量在反向传播的时候,如果不是叶子节点,它的梯度就会消失。由于反向传播过程中存在数据的释放,所以就有了反向传播的hook函数

  • Hook函数仅一个输入参数,为张量的梯度

    hook(grad) -> Tensor or None
    

  下面,我们通过计算图流程来观察张量梯度的获取以及熟悉Hook函数。
y = ( x + w ) ∗ ( w + 1 ) y=(x+w)*(w+1) y=(x+w)(w+1)
PyTorch学习—16.PyTorch中hook函数_第1张图片

import torch
import torch.nn as nn

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

# 存储张量的梯度
a_grad = list()

def grad_hook(grad):
    """
    定义一个hook函数,将梯度存储到列表中
    :param grad: 梯度
    :return:
    """
    a_grad.append(grad)

# 注册一个反向传播的hook函数,功能是将梯度存储到a_grad列表中
handle = a.register_hook(grad_hook)
# 反向传播
y.backward()

# 查看梯度
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
print("a_grad[0]: ", a_grad[0])
handle.remove()
tensor([5.]) tensor([2.]) None None None
a_grad[0]:  tensor([2.])

如果对叶子节点的张量使用hook函数,那么会怎么样呢?

import torch
import torch.nn as nn

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

a_grad = list()

def grad_hook(grad):
    grad *= 2
    return grad*3

handle = w.register_hook(grad_hook)

y.backward()

# 查看梯度
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
print("w.grad: ", w.grad)
handle.remove()
gradient: tensor([30.]) tensor([2.]) None None None
w.grad:  tensor([30.])

与上面比较,发现hook函数相当于对已有张量进行原地操作

2. Module.register_forward_hook

def register_forward_hook(self, hook):
	...

功能:注册module的前向传播hook函数
参数:

  • module: 当前网络层
  • input:当前网络层输入数据
  • output:当前网络层输出数据

3.Module.register_forward_pre_hook

功能:注册module前向传播前的hook函数
参数:

  • module: 当前网络层
  • input:当前网络层输入数据

4. Module.register_backward_hook

功能:注册module反向传播的hook函数
参数:

  • module: 当前网络层
  • grad_input:当前网络层输入梯度数据
  • grad_output:当前网络层输出梯度数据

下面例子展示这三个hook函数
PyTorch学习—16.PyTorch中hook函数_第2张图片

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.pool1 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        return x

def forward_hook(module, data_input, data_output):
    """
    定义前向传播hook函数
    :param module:网络
    :param data_input:输入数据
    :param data_output:输出数据
    """
    fmap_block.append(data_output)
    input_block.append(data_input)

def forward_pre_hook(module, data_input):
    """
    定义前向传播前的hook函数
    :param module: 网络
    :param data_input: 输入数据
    :return:
    """
    print("forward_pre_hook input:{}".format(data_input))

def backward_hook(module, grad_input, grad_output):
    """
    定义反向传播的hook函数
    :param module: 网络
    :param grad_input: 输入梯度
    :param grad_output: 输出梯度
    :return:
    """
    print("backward hook input:{}".format(grad_input))
    print("backward hook output:{}".format(grad_output))

# 初始化网络
net = Net()
# 第一个卷积核全设置为1
net.conv1.weight[0].detach().fill_(1)
# 第二个卷积核全设置为2
net.conv1.weight[1].detach().fill_(2)
# bias不考虑
net.conv1.bias.data.detach().zero_()

# 注册hook
fmap_block = list()
input_block = list()
# 给卷积层注册前向传播hook函数
net.conv1.register_forward_hook(forward_hook)
# 给卷积层注册前向传播前的hook函数
net.conv1.register_forward_pre_hook(forward_pre_hook)
# 给卷积层注册反向传播的hook函数
net.conv1.register_backward_hook(backward_hook)

# inference
fake_img = torch.ones((1, 1, 4, 4))   # batch size * channel * H * W
output = net(fake_img)

loss_fnc = nn.L1Loss()
target = torch.randn_like(output)
loss = loss_fnc(target, output)
loss.backward()

# 观察
print("output shape: {}\noutput value: {}\n".format(output.shape, output))
print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))
forward_pre_hook input:(tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]]),)
backward hook input:(None, tensor([[[[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]]],
        [[[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]]]]), tensor([0.5000, 0.5000]))
backward hook output:(tensor([[[[0.5000, 0.0000],
          [0.0000, 0.0000]],
         [[0.5000, 0.0000],
          [0.0000, 0.0000]]]]),)
output shape: torch.Size([1, 2, 1, 1])
output value: tensor([[[[ 9.]],
         [[18.]]]], grad_fn=<MaxPool2DWithIndicesBackward>)
feature maps shape: torch.Size([1, 2, 2, 2])
output value: tensor([[[[ 9.,  9.],
          [ 9.,  9.]],
         [[18., 18.],
          [18., 18.]]]], grad_fn=<MkldnnConvolutionBackward>)
input shape: torch.Size([1, 1, 4, 4])
input value: (tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]]),)

如果对您有帮助,麻烦点赞关注,这真的对我很重要!!!如果需要互关,请评论或者私信!
在这里插入图片描述


你可能感兴趣的:(PyTorch框架学习,PyTorch,hook函数)