Pytorch forward方法调用原理

   在使用Pytorch自定义网络模型的时候,我们需要继承nn.Module这个类,然后定义forward方法来实现前向转播。如下图的一个自定义的网络模型
Pytorch forward方法调用原理_第1张图片
首先该网络模型的初始化方法__init__需要继承父类nn.Module的初始化方法,用语句super().init()实现。并在初始化方法里面,定义了卷积、BN、激活函数等。接下来定义forward方法,将整个网络连接起来。
   有了上面的定义,我们可以实例化一个对象,例如:
fire2 = Fire(96, 128,16,64,64)
实现前向传播,使用
y= fire2(x)
其中x是该网络的输入,y是输出,实现了forward方法的额功能。这里就会有人感到奇怪,forward作为Fire这个类的方法,使用的时候不应该是
y= fire2.forward(x)
吗。这里为什么一个类的实例可以当做方法直接使用?这是因为这个Fire类继承的父类nn.Module里面定义了__call__方法。一个类如果定义了__call__方法,则该类的实例就可以作为一个方法那样直接使用。例如下列代码[1]

class A():
    def __call__(self):
        print('i can be called like a function')
 
a = A()
a()

就会执行print函数,打印其中搞的文字。这里需要区别的是,实例化的时候,类的名称后面括号可以传递参数,例如前面实例化Fire的时候,传递in_channel,out_channel等参数。但是要利用__call__的特性,是在实例名后面的括号中传递参数,例如上面的例子a(),这里虽然没有参数,但是也可以改变__call__的定义使之可以传递参数。
   回到网络模型的内容上来。翻看nn.Module的部分源码[2],可以发现,nn.Module里面果然定义了__call__,并且传递了参数*input。在__call__的定义中国,调用了self.forward。
Pytorch forward方法调用原理_第2张图片
这里其实还有一个点值得注意。其实nn.Module里面并没有定义forward,但他却调用self.forward,严格来说,他是“想要”调用self.forward。如果我们没有定义一个类,例如Fire,来继承nn.Module,并且在这个类里面定义forward,那么nn.Module中__call__下面的self.forward就是无效的。这意味着,父类中__call__下面调用的函数,可以在继承他的子类中定义。下面给出一个简单的例子。

class father():
    def __call__(self):
        self.forward()
        print('I''m the father!')

class child(father):
    def forward(self):
        print('Forward!')
F=father()
C=child()

这里定义了父类father,并定义了继承他的一个子类child。此外还进行了他们的实例化。显然,在father的__call__方法下面,调用了self.forward,但是没有定义。child在继承了father之后,定义了forward。首先,这段代码不会报错,即使father的__call__下面的self.forward并没有定义,这也是前面我说的,虽然没有定义forward,但是可以理解为他“想要”调用self.forward。那么在child记成了father之后,进行了forward的定义,这使得child本身可以调用forward。
   在上面这段代码的基础上,如果我们执行F(),汇报下面这一段错误,这解释了forward没有定义,只是“想要”调用self.forward。
在这里插入图片描述
如果我们执行C(),则如下图输出。显然,在child中补充了forward的定义,就可以成功调用。
在这里插入图片描述

参考文献

[1] JY丫丫,pytorch 中的 forward 的使用与解释,2019-07-25。
[2] 墨氲,pytorch系列 ----暂时就叫5的番外吧: nn.Modlue及nn.Linear 源码理解,2019-10-09。

你可能感兴趣的:(pytorch,python,深度学习)