参考链接:https://www.cnblogs.com/hellcat/p/8512090.html
由于pytorch会自动舍弃图计算的中间结果,所以想要获取这些数值就需要使用hook函数。hook函数包括tensor的hook和nn.Module的hook,用法相似。hook函数在使用后应及时删除,以避免每次都运行钩子增加运行负载。hook函数主要用在获取某些中间结果的情景,如中间某一层的输出或某一层的梯度。这些结果本应写在forward函数中,但如果在forward函数中专门加上这些处理,可能会使处理逻辑比较复杂,这时候使用hook技术就更合适一些
参考:https://pytorch.org/docs/stable/tensors.html
有如下的register_hook(hook)
方法,为Tensor注册一个backward hook,用来获取变量的梯度。
hook必须遵循如下的格式:hook(grad) -> Tensor or None
,其中grad为获取的梯度
具体的实例如下:
import torch
grad_list = []
def print_grad(grad):
grad = grad * 2
grad_list.append(grad)
x = torch.tensor([[1., -1.], [1., 1.]], requires_grad=True)
h = x.register_hook(print_grad) # double the gradient
out = x.pow(2).sum()
out.backward()
print(grad_list)
'''
[tensor([[ 4., -4.],
[ 4., 4.]])]
'''
# 删除hook函数
h.remove()
有register_forward_hook(hook)
和register_backward_hook(hook)
两种方法,分别对应前向传播和反向传播的hook函数。
在网络执行forward()
之后,执行hook函数,需要具有如下的形式:
hook(module, input, output) -> None or modified output
hook可以修改input和output,但是不会影响forward的结果。最常用的场景是需要提取模型的某一层(不是最后一层)的输出特征,但又不希望修改其原有的模型定义文件,这时就可以利用forward_hook函数。
import torch
import torch.nn as nn
import torch.nn.functional as F
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
out = F.relu(self.conv1(x))
out = F.max_pool2d(out, 2)
out = F.relu(self.conv2(out))
out = F.max_pool2d(out, 2)
out = out.view(out.size(0), -1)
out = F.relu(self.fc1(out))
out = F.relu(self.fc2(out))
out = self.fc3(out)
return out
features = []
def hook(module, input, output):
features.append(output.clone().detach())
net = LeNet()
x = torch.randn(2, 3, 32, 32)
handle = net.conv2.register_forward_hook(hook)
y = net(x)
print(features[0].size())
handle.remove()
每一次module的inputs的梯度被计算后调用hook,hook必须具有如下的签名:
hook(module, grad_input, grad_output) -> Tensor or None
grad_input
和 grad_output
参数分别表示输入的梯度和输出的梯度,是不能修改的,但是可以通过return一个梯度元组tuple来替代grad_input
。
展示一个实例来解析grad_input
和 grad_output
参数:
import torch
import torch.nn as nn
def hook(module, grad_input, grad_output):
print('grad_input: ', grad_input)
print('grad_output: ', grad_output)
x = torch.tensor([[1., 2., 10.]], requires_grad=True)
module = nn.Linear(3, 1)
handle = module.register_backward_hook(hook)
y = module(x)
y.backward()
print('module_weight: ', module.weight.grad)
handle.remove()
输出:
grad_input: (tensor([1.]), tensor([[ 0.1236, -0.0232, -0.5687]]), tensor([[ 1.],
[ 2.],
[10.]]))
grad_output: (tensor([[1.]]),)
module_weight: tensor([[ 1., 2., 10.]])
可以看出,grad_input元组包含(bias的梯度
,输入x的梯度
,权重weight的梯度
),grad_output元组包含输出y的梯度。
可以在hook函数中通过return来修改grad_input
:
import torch
import torch.nn as nn
def hook(module, grad_input, grad_output):
print('grad_input: ', grad_input)
print('grad_output: ', grad_output)
return grad_input[0] * 0, grad_input[1] * 0, grad_input[2] * 0,
x = torch.tensor([[1., 2., 10.]], requires_grad=True)
module = nn.Linear(3, 1)
handle = module.register_backward_hook(hook)
y = module(x)
y.backward()
print('module_bias: ', module.bias.grad)
print('x: ', x.grad)
print('module_weight: ', module.weight.grad)
handle.remove()
输出:
grad_input: (tensor([1.]), tensor([[ 0.1518, 0.0798, -0.3170]]), tensor([[ 1.],
[ 2.],
[10.]]))
grad_output: (tensor([[1.]]),)
module_bias: tensor([0.])
x: tensor([[0., 0., -0.]])
module_weight: tensor([[0., 0., 0.]])
对于没有参数的Module,比如nn.ReLU
来说,grad_input元组包含(输入x的梯度
),grad_output元组包含(输出y的梯度
)。
def hook(module, grad_input, grad_output):
print('grad_input: ', grad_input)
print('grad_output: ', grad_output)
return (grad_input[0] / 4, )
x = torch.tensor([-1., 2., 10.], requires_grad=True)
module = nn.ReLU()
handle = module.register_backward_hook(hook)
y = module(x).sum()
z = y * y
z.backward()
print(x.grad) # tensor([0., 6., 6.])
handle.remove()
输出:
grad_input: (tensor([ 0., 24., 24.]),)
grad_output: (tensor([24., 24., 24.]),)
tensor([0., 6., 6.])
y = R e L U ( x 1 ) + R e L U ( x 2 ) + R e L U ( x 3 ) y=ReLU(x_{1})+ReLU(x_{2})+ReLU(x_{3}) y=ReLU(x1)+ReLU(x2)+ReLU(x3)
z = y 2 z=y^{2} z=y2
grad_output是传到ReLU模块的输出值的梯度,即 ∂ z ∂ y = 2 y = 24 \frac{\partial z}{\partial y}=2y=24 ∂y∂z=2y=24。
grad_input是进入ReLU模块的输入值的梯度,由 ∂ y ∂ x 1 = 0 , ∂ y ∂ x 2 = 1 , ∂ y ∂ x 3 = 1 \frac{\partial y}{\partial x_{1}}=0,\frac{\partial y}{\partial x_{2}}=1,\frac{\partial y}{\partial x_{3}}=1 ∂x1∂y=0,∂x2∂y=1,∂x3∂y=1,可得:
∂ z ∂ y ∂ y ∂ x 1 = 0 , ∂ z ∂ y ∂ y ∂ x 2 = 24 , ∂ z ∂ y ∂ y ∂ x 3 = 24 \frac{\partial z}{\partial y}\frac{\partial y}{\partial x_{1}}=0,\frac{\partial z}{\partial y}\frac{\partial y}{\partial x_{2}}=24,\frac{\partial z}{\partial y}\frac{\partial y}{\partial x_{3}}=24 ∂y∂z∂x1∂y=0,∂y∂z∂x2∂y=24,∂y∂z∂x3∂y=24
在hook函数中可以对输入值 x x x的梯度进行缩放:
[ 0 , 24 , 24 ] / 4 = [ 0 , 6 , 6 ] [0,24,24]/4=[0,6,6] [0,24,24]/4=[0,6,6]