在理解register_hook之前,首先得搞懂什么叶子节点和非叶子节。简单来说叶子节点是有梯度且独立得张量,例如a = torch.tensor(2.0,requires_grad=True),b= torch.tensor(3.0,requires_grad=True)
,非叶子节点是依赖其他张量而得到得张量如c = a+b
。
判断是叶子节点还是非叶子节点可以使用 is_leaf
来判断一个张量是叶子节点还是非叶子节点。
import torch
a = torch.tensor(2.0,requires_grad=True)
b = torch.tensor(3.0,requires_grad=True)
print(a.is_leaf)
print(b.is_leaf)
c = a +b
print(c.is_leaf)
>>> True
>>> True
>>> False
中间张量 c 作为非叶子节点是没有梯度信息得。pytorch默认在梯度反向传播过程中不会记录中间变量梯度信息。而且叶子节点的梯度信息在反向传播流过程中是不允许我们修改的。只能通过print(a.grad)
查看张量的梯度信息。
那么,如果我们想查看中间变量 c 以及想改变叶子节点反向传播过程中的梯度值,应该怎么办呢。这时候就要使用register_hook这个钩子函数了。通过一下两段代码看一下钩子函数的主要作用。
a = torch.tensor(2.0,requires_grad=True)
b = torch.tensor(3.0,requires_grad=True)
print(a.grad)
print(b.grad)
c = a*b
print(c.grad) # 由于c是叶子节点,所以他是不记录梯度信息得。前后打印梯度信息都为None
d = torch.tensor(4.0,requires_grad=True)
e = c * d
e.backward()
print(a.grad)
print(b.grad)
print(c.grad)
>>>输出
None
None
None
tensor(12.)
tensor(8.)
None
通过上面代码可以看出,c
作为中间变量在反向传播过程中不记录梯度信息。c=a*b
其中a的梯度就为b
的值,b的梯度就是a的值。接下来对中间变量c 使用register_hook
,这个函数传入的参数得是一个函数。
import torch
a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
c = a * b
def c_hook(grad):
print("c_hook",grad)
return grad + 2 # 什么也不返回的话用的是和之前一样的梯度,不对其进行变化。
# 在c中,钩子按照有序字典的方式存储,按照存储的前后一次调用
c.register_hook(c_hook)
c.register_hook(lambda grad: print("hello my grad is",grad))
c.retain_grad() # 存储中间变量的梯度
print(a.grad)
print(b.grad)
print(c.grad)
c.backward()
print(a.grad)
print(b.grad)
print(c.grad)
>>>
None
None
None
c_hook tensor(1.)
hello my grad is tensor(3.)
tensor(9.)
tensor(6.)
tensor(3.)
为什么输出会是这样的结果呢,一个张量可以注册多个钩子函数,反向传播过程中按照注册的顺序依次运行。 c.register_hook(c_hook) c.register_hook(lambda grad:)
,这两个函数可以重写c的梯度,第一个函数传入的参数是c的梯度,自身对自身的梯度pytorch中默认为1。所以此时c_hook
中传入的grad=1
,这个函数返回值为grad+2=3
,此时会重写中间变量c的梯度信息。第二个钩子函数传入的函数为匿名函数,这个匿名函数对c的梯度没有进行重写,使用的还是上一个钩子函数重写的值,此使打印信息就为3。最后通过c.retain_grad()记c的梯度信息。通过这个例子,我稍微懂了点register_hook这个钩子函数的作用,是不是本来不可修改的梯度信息值,通过这个函数修改了呢。
通过一下这个例子比较再来看一下registe_hook函数的作用。
import torch
a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
c = a * b
def c_hook(grad):
print("c_hook",grad)
return grad + 2 # 什么也不返回的话用的是和之前一样的梯度,不对其进行变化。
# 在c中,钩子按照有序字典的方式存储,按照存储的前后一次调用
c.register_hook(c_hook)
c.register_hook(lambda grad: print("hello my grad is",grad))
c.retain_grad() # 存储中间变量的梯度
d = torch.tensor(4.0, requires_grad=True)
d.register_hook(lambda grad: grad + 100) # 将使用100+grad代替本来返回得梯度值
e = c * d
print(a.grad)
print(b.grad)
print(c.grad)
print(d.grad)
print(e.grad)
# e.retain_grad()
e.register_hook(lambda grad: grad * 2)
e.retain_grad()
e.backward()
print(a.grad)
print(b.grad)
print(c.grad)
print(d.grad)
print(e.grad)
>>>输出
None
None
None
None
None
c_hook tensor(8.)
hello my grad is tensor(10.)
tensor(30.)
tensor(20.)
tensor(10.)
tensor(112.)
tensor(2.)
这段代码前部分和前面的代码保持一致,后面添加了e = c * d
,在反向传播前,毋庸置疑a,b,,c,d,e
的梯度都为None
。反向传播过程中首先看 e
,自身对自身的倒数默认为1,但是e
注册的钩子将对原本的梯度 * 2 ,来替代原先的梯度信息,所以打印出的e的梯度信息为2。相应的,e 对 c的梯度信息相应的就变为 2d=8
,e对d的梯度信息就变为 2c=12
,案例说此使d的梯度信息为12,为什么是112呢,可以看出d注册了一个钩子函数,这个钩子给d原本的梯度信息加了100,来代替旧的梯度信息,所以d的梯度信息为112。由于c注册的钩子函数给他加了2,所以c的梯度信息为10。相应的a b 的梯度就都要乘以c 的梯度信息了。 同样,原本不变的梯度信息值在这里都根据register_hook
这个函数相应的被重写。
以上就是我根据视频链接对register_hook的理解。
register_forward_hook register_forward_pre_hook
这个函数主要使用在nn.Module
网络中。
第一个函数看名称是用在网络forward
之前,第二个是运行在forward
之后,举例:
import torch
import torch.nn as nn
class SumNet(nn.Module):
def __init__(self):
super(SumNet, self).__init__()
@staticmethod
def forward(a, b, c):
d = a + b + c
print('forward():')
print(' a:', a)
print(' b:', b)
print(' c:', c)
print()
print(' d:', d)
print()
return d
def forward_pre_hook(module, input_positional_args):
a, b, c = input_positional_args
new_input_positional_args = a + 10, b,c+10
print('forward_pre_hook():')
print(' module:', module)
print(' input_positional_args:', input_positional_args)
print()
print(' new_input_positional_args:', new_input_positional_args)
print()
return new_input_positional_args
def forward_hook(module, input_positional_args, output):
new_output = output + 100
print('forward_hook():')
print(' module:', module)
print(' input_positional_args:', input_positional_args)
print(' output:', output)
print()
print(' new_output:', new_output)
print()
return new_output
def main():
sum_net = SumNet()
sum_net.register_forward_pre_hook(forward_pre_hook)
sum_net.register_forward_hook(forward_hook)
a = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(2.0, requires_grad=True)
c = torch.tensor(3.0, requires_grad=True)
print('start')
print()
print('a:', a)
print('b:', b)
print('c:', c)
print()
print('before model')
print()
d = sum_net(a, b, c) # 前向传播得时候钩子函数起作用了,先是forward_pre_hook,接下来是forward,接下来是forward_hook函数。
print('after model')
print()
print('d:', d)
if __name__ == '__main__':
main()
输出信息:
start
a: tensor(1., requires_grad=True)
b: tensor(2., requires_grad=True)
c: tensor(3., requires_grad=True)
before model
forward_pre_hook():
module: SumNet()
input_positional_args: (tensor(1., requires_grad=True), tensor(2., requires_grad=True), tensor(3., requires_grad=True))
new_input_positional_args: (tensor(11., grad_fn=<AddBackward0>), tensor(2., requires_grad=True), tensor(13., grad_fn=<AddBackward0>))
forward():
a: tensor(11., grad_fn=<AddBackward0>)
b: tensor(2., requires_grad=True)
c: tensor(13., grad_fn=<AddBackward0>)
d: tensor(26., grad_fn=<AddBackward0>)
forward_hook():
module: SumNet()
input_positional_args: (tensor(11., grad_fn=<AddBackward0>), tensor(2., requires_grad=True), tensor(13., grad_fn=<AddBackward0>))
output: tensor(26., grad_fn=<AddBackward0>)
new_output: tensor(126., grad_fn=<AddBackward0>)
after model
d: tensor(126., grad_fn=<AddBackward0>)
分析以上为什么会输出这样的结果,前面提到register_forward_hook
这个函数会在网络前向传播前运行,需要两个参数modul 和 input
案例中输入为 tensor 1 2 3
,经过这个函数给2 3 分别加了10,并且返回了一组新的值,这组值是要传入forward
中,可以看出,forward
函数打印的a b c 为传入的这组新值,而不是刚开始定义的1 2 3,forward
函数运行过程中返回每层的输出会运行forward_hook
函数。这个函数主要需要三个参数,module input output
。
以下从Lenet网络来使用这个函数:
import torch
import torch.nn as nn
import torch.nn.functional as F
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
out = self.conv1(x)
out = F.relu(out)
out = F.max_pool2d(out, 2)
out = self.conv2(out)
out = F.relu(out)
out = F.max_pool2d(out, 2)
out = out.view(out.size(0), -1)
out = F.relu(self.fc1(out))
out = F.relu(self.fc2(out))
out = self.fc3(out)
return out
model = LeNet()
# 分别对model的第一个卷积层和最后一层使用了钩子函数,这样既可以取出对应层的输出。
def hook(model,input_,output):
print("最后一层输出:",output.shape)
def conv_hook(model,input_,output):
print("conv1后",input_[0].shape,output.shape)
model.register_forward_hook(hook)
model.conv1.register_forward_hook(conv_hook)
img = torch.randn([1,3,32,32])
out_put = model(img)
>>>
conv1后 torch.Size([1, 3, 32, 32]) torch.Size([1, 6, 28, 28])
最后一层输出: torch.Size([1, 10])
基于上可以看出给不同层使用钩子函数,可以提取出每一层的输出,并进行相应的处理。
以上就是pytorch中register_hook
和register_forward_hook
的基本理解。
如果有问题烦请指出加以改正。