PyTorch 学习笔记(六):PyTorch hook 和关于 PyTorch backward 过程的理解 call

您的位置 首页 PyTorch 学习笔记系列

PyTorch 学习笔记(六):PyTorch hook 和关于 PyTorch backward 过程的理解

PyTorch入门实战教程

在看pytorch官方文档的时候,发现在nn.Module部分和Variable部分均有hook的身影。感到很神奇,因为在使用tensorflow的时候没有碰到过这个词。所以打算一探究竟。

文章目录 [隐藏]

  • 1 Variable 的 hook
    • 1.1 register_hook(hook)
  • 2 nn.Module的hook
    • 2.1 register_forward_hook(hook)
  • 3 register_backward_hook

Variable 的 hook

register_hook(hook)

注册一个backward钩子。

每次gradients被计算的时候,这个hook都被调用。hook应该拥有以下签名:

hook不应该修改它的输入,但是它可以返回一个替代当前梯度的新梯度。

这个函数返回一个 句柄(handle)。它有一个方法 handle.remove(),可以用这个方法将hook从module移除。

例子:

输出:

nn.Module的hook

register_forward_hook(hook)

在module上注册一个forward hook。

这里要注意的是,hook 只能注册到 Module 上,即,仅仅是简单的 op 包装的 Module,而不是我们继承 Module时写的那个类,我们继承 Module写的类叫做 Container。

每次调用forward()计算输出的时候,这个hook就会被调用。它应该拥有以下签名:

hook不应该修改 input和output的值。 这个函数返回一个 句柄(handle)。它有一个方法 handle.remove(),可以用这个方法将hook从module移除。

看这个解释可能有点蒙逼,但是如果要看一下nn.Module的源码怎么使用hook的话,那就乌云尽散了。

先看 register_forward_hook

这个方法的作用是在此module上注册一个hook,函数中第一句就没必要在意了,主要看第二句,是把注册的hook保存在_forward_hooks字典里。

再看 nn.Module 的__call__方法(被阉割了,只留下需要关注的部分):