Hook函数机制:不改变主体,实现额外功能,像一个挂件一样将功能挂到函数主体上。Hook函数与PyTorch中的动态图运算机制有关,因为在动态图计算,在运算结束后,中间变量是会被释放掉的,例如:非叶子节点的梯度。但是,我们往往想要提取这些中间变量,这时,我们就可以采用Hook函数在前向传播与反向传播主体上挂上一些额外的功能(函数),通过这些函数获取中间的梯度,甚至是改变中间的梯度。PyTorch一共提供了四种Hook函数:
一种是针对Tensor,其余三种是针对网络的
def register_hook(self, hook):
"""
接受一个hook函数
"""
...
功能:注册一个反向传播hook函数,这是因为张量在反向传播的时候,如果不是叶子节点,它的梯度就会消失。由于反向传播过程中存在数据的释放,所以就有了反向传播的hook函数
Hook函数仅一个输入参数,为张量的梯度
hook(grad) -> Tensor or None
下面,我们通过计算图流程来观察张量梯度的获取以及熟悉Hook函数。
y = ( x + w ) ∗ ( w + 1 ) y=(x+w)*(w+1) y=(x+w)∗(w+1)
import torch
import torch.nn as nn
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
# 存储张量的梯度
a_grad = list()
def grad_hook(grad):
"""
定义一个hook函数,将梯度存储到列表中
:param grad: 梯度
:return:
"""
a_grad.append(grad)
# 注册一个反向传播的hook函数,功能是将梯度存储到a_grad列表中
handle = a.register_hook(grad_hook)
# 反向传播
y.backward()
# 查看梯度
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
print("a_grad[0]: ", a_grad[0])
handle.remove()
tensor([5.]) tensor([2.]) None None None
a_grad[0]: tensor([2.])
如果对叶子节点的张量使用hook函数,那么会怎么样呢?
import torch
import torch.nn as nn
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
a_grad = list()
def grad_hook(grad):
grad *= 2
return grad*3
handle = w.register_hook(grad_hook)
y.backward()
# 查看梯度
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
print("w.grad: ", w.grad)
handle.remove()
gradient: tensor([30.]) tensor([2.]) None None None
w.grad: tensor([30.])
与上面比较,发现hook函数相当于对已有张量进行原地操作
def register_forward_hook(self, hook):
...
功能:注册module的前向传播hook函数
参数:
功能:注册module前向传播前的hook函数
参数:
功能:注册module反向传播的hook函数
参数:
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 2, 3)
self.pool1 = nn.MaxPool2d(2, 2)
def forward(self, x):
x = self.conv1(x)
x = self.pool1(x)
return x
def forward_hook(module, data_input, data_output):
"""
定义前向传播hook函数
:param module:网络
:param data_input:输入数据
:param data_output:输出数据
"""
fmap_block.append(data_output)
input_block.append(data_input)
def forward_pre_hook(module, data_input):
"""
定义前向传播前的hook函数
:param module: 网络
:param data_input: 输入数据
:return:
"""
print("forward_pre_hook input:{}".format(data_input))
def backward_hook(module, grad_input, grad_output):
"""
定义反向传播的hook函数
:param module: 网络
:param grad_input: 输入梯度
:param grad_output: 输出梯度
:return:
"""
print("backward hook input:{}".format(grad_input))
print("backward hook output:{}".format(grad_output))
# 初始化网络
net = Net()
# 第一个卷积核全设置为1
net.conv1.weight[0].detach().fill_(1)
# 第二个卷积核全设置为2
net.conv1.weight[1].detach().fill_(2)
# bias不考虑
net.conv1.bias.data.detach().zero_()
# 注册hook
fmap_block = list()
input_block = list()
# 给卷积层注册前向传播hook函数
net.conv1.register_forward_hook(forward_hook)
# 给卷积层注册前向传播前的hook函数
net.conv1.register_forward_pre_hook(forward_pre_hook)
# 给卷积层注册反向传播的hook函数
net.conv1.register_backward_hook(backward_hook)
# inference
fake_img = torch.ones((1, 1, 4, 4)) # batch size * channel * H * W
output = net(fake_img)
loss_fnc = nn.L1Loss()
target = torch.randn_like(output)
loss = loss_fnc(target, output)
loss.backward()
# 观察
print("output shape: {}\noutput value: {}\n".format(output.shape, output))
print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))
forward_pre_hook input:(tensor([[[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]]]]),)
backward hook input:(None, tensor([[[[0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000]]],
[[[0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000]]]]), tensor([0.5000, 0.5000]))
backward hook output:(tensor([[[[0.5000, 0.0000],
[0.0000, 0.0000]],
[[0.5000, 0.0000],
[0.0000, 0.0000]]]]),)
output shape: torch.Size([1, 2, 1, 1])
output value: tensor([[[[ 9.]],
[[18.]]]], grad_fn=<MaxPool2DWithIndicesBackward>)
feature maps shape: torch.Size([1, 2, 2, 2])
output value: tensor([[[[ 9., 9.],
[ 9., 9.]],
[[18., 18.],
[18., 18.]]]], grad_fn=<MkldnnConvolutionBackward>)
input shape: torch.Size([1, 1, 4, 4])
input value: (tensor([[[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]]]]),)
如果对您有帮助,麻烦点赞关注,这真的对我很重要!!!如果需要互关,请评论或者私信!