PyTorch Learning Path (1) | How a Model's forward Method Gets Called

Question

out = net(image)  # the image is the input; net runs the forward pass and returns the output (class scores / boxes / ...)

Have you ever wondered how the line above ends up calling forward() to produce its result? Below I paste the relevant source code and explain it.

Answer

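Let's trace, step by step, what net(image) actually goes through. (The explanation quotes the open-source code in question, simplified where needed to make the point.)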

  1. Definition of net
net = RetinaFace()
  2. Definition of the RetinaFace class
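import torch.nn as nn

class RetinaFace(nn.Module):
    def __init__(self):
        super().__init__()  # must run before any submodule is assigned
        # Define the layers, for example (FPN's constructor arguments are omitted here for brevity):
        self.fpn = FPN()
    def forward(self, inputs):
        out = self.fpn(inputs)
        return out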
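The same nesting can be reproduced in a few lines. ToyRetinaFace and ToyFPN below are hypothetical stand-ins: a single nn.Linear replaces the real layers. Each one appends to a shared list whenever its forward runs, so the call order becomes visible.

```python
import torch
import torch.nn as nn


# Hypothetical toy classes mirroring the RetinaFace -> FPN nesting above.
class ToyFPN(nn.Module):
    def __init__(self, log):
        super().__init__()
        self.log = log  # plain Python list shared with the parent, used for tracing
        self.proj = nn.Linear(4, 2)

    def forward(self, x):
        self.log.append("ToyFPN.forward")
        return self.proj(x)


class ToyRetinaFace(nn.Module):
    def __init__(self, log):
        super().__init__()
        self.log = log
        self.fpn = ToyFPN(log)

    def forward(self, inputs):
        self.log.append("ToyRetinaFace.forward")
        return self.fpn(inputs)  # goes through ToyFPN.__call__, which then calls forward


calls = []
net = ToyRetinaFace(calls)
out = net(torch.randn(3, 4))  # forward is never called by hand
print(calls)      # ['ToyRetinaFace.forward', 'ToyFPN.forward']
print(out.shape)  # torch.Size([3, 2])
```

Only net(...) is written by hand. Both forward methods are reached through nn.Module.__call__.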
  3. Definition of the FPN class
class FPN(nn.Module):
    def __init__(self, in_channels_list, out_channels):
        super().__init__()  # must run before any submodule is assigned
        # conv_bn1X1: helper defined elsewhere in the original code (by its name, a 1x1 conv + BatchNorm block)
        self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1)
        self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1)
        self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1)

    def forward(self, inputs):
        # inputs is a dict of backbone feature maps; take them in insertion order
        feats = list(inputs.values())

        output1 = self.output1(feats[0])
        output2 = self.output2(feats[1])
        output3 = self.output3(feats[2])

        out = [output1, output2, output3]
        return out
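Running the simplified FPN end to end requires a definition of conv_bn1X1. The sketch below assumes it is a 1x1 Conv2d followed by BatchNorm2d and LeakyReLU. That is an assumption based on the helper's name, not code copied from the repository. The sketch then feeds in an OrderedDict of three feature maps whose channel counts and sizes are made up for the demo.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def conv_bn1X1(inp, oup, stride, leaky=0.0):
    # Assumed stand-in: 1x1 convolution -> BatchNorm -> LeakyReLU
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size=1, stride=stride, padding=0, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    )


class FPN(nn.Module):
    def __init__(self, in_channels_list, out_channels):
        super().__init__()
        self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1)
        self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1)
        self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1)

    def forward(self, inputs):
        feats = list(inputs.values())  # backbone feature maps, in insertion order
        return [self.output1(feats[0]), self.output2(feats[1]), self.output3(feats[2])]


# Hypothetical channel counts and spatial sizes, chosen only for this demo.
fpn = FPN(in_channels_list=[8, 16, 32], out_channels=4)
features = OrderedDict(
    c3=torch.randn(1, 8, 32, 32),
    c4=torch.randn(1, 16, 16, 16),
    c5=torch.randn(1, 32, 8, 8),
)
outs = fpn(features)  # nn.Module.__call__ -> FPN.forward
print([tuple(o.shape) for o in outs])  # [(1, 4, 32, 32), (1, 4, 16, 16), (1, 4, 8, 8)]
```

Each 1x1 block changes only the channel count (to out_channels) and leaves the spatial size unchanged.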
  4. Every class has a forward method, so who actually calls it? Here is the key part: the definition of nn.Module (only the relevant members are shown):
# Excerpt from torch/nn/modules/module.py in an older PyTorch 1.x release (only the relevant members).
# Newer releases move the body of __call__ into Module._call_impl, but forward is still invoked from there.
class Module(object):
    def forward(self, *input):
        r"""Defines the computation performed at every call.

        Should be overridden by all subclasses.

        .. note::
            Although the recipe for forward pass needs to be defined within
            this function, one should call the :class:`Module` instance afterwards
            instead of this since the former takes care of running the
            registered hooks while the latter silently ignores them.
        """
        raise NotImplementedError

    def __call__(self, *input, **kwargs):
        for hook in self._forward_pre_hooks.values():
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)  # the key line!
        for hook in self._forward_hooks.values():
            hook_result = hook(self, input, result)
            if hook_result is not None:
                result = hook_result
        if len(self._backward_hooks) > 0:
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in self._backward_hooks.values():
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
        return result
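The note in the docstring above says to call the Module instance rather than forward, because only the former runs registered hooks. This can be checked directly. The minimal sketch below uses nn.Linear as an arbitrary layer together with the public methods register_forward_pre_hook and register_forward_hook. The hook functions and the seen list are illustrative names.

```python
import torch
import torch.nn as nn

seen = []
layer = nn.Linear(4, 2)


def pre_hook(module, inputs):
    # Runs inside __call__, before forward
    seen.append("pre_hook")


def post_hook(module, inputs, output):
    # Runs inside __call__, after forward
    seen.append("forward_hook")


layer.register_forward_pre_hook(pre_hook)
layer.register_forward_hook(post_hook)

x = torch.randn(3, 4)

layer(x)  # goes through __call__, so both hooks fire
after_call = list(seen)

seen.clear()
layer.forward(x)  # bypasses __call__, so the hooks are silently skipped
after_forward = list(seen)

print(after_call)     # ['pre_hook', 'forward_hook']
print(after_forward)  # []
```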
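  • forward is called from inside __call__, and __call__ is what Python invokes when an instance of a class is followed by parentheses `()`. This is the counterpart of overloading operator() in C++. If you don't know C++, that's fine: normally you call a method defined in a class as example_class_instance.func(), whereas writing just example_class_instance() invokes the special method __call__.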
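  • With that in mind, go back to step 1: out = net(image) actually calls net's __call__ method. RetinaFace does not define __call__ itself, so the method inherited from its parent class, nn.Module.__call__, is used. That method calls self.forward, and because RetinaFace defines its own forward, the overridden version is the one that runs.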
  • By the same logic, out = self.fpn(inputs) first calls __call__, which in turn calls forward. Since FPN overrides forward, FPN's version runs.
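The dispatch can be seen without any PyTorch machinery. Below is a hypothetical MiniModule, not PyTorch's code, that keeps only the outline of nn.Module.__call__. Defining __call__ plays the role of C++'s operator(). Because self.forward is resolved on the instance's actual class, the inherited __call__ runs the subclass's forward. Double and AddOneThenDouble are made-up names standing in for FPN and RetinaFace.

```python
class MiniModule:
    """Hypothetical, heavily simplified stand-in for nn.Module (not PyTorch code)."""

    def __init__(self):
        self._forward_hooks = []

    def forward(self, *args):
        # Same contract as nn.Module.forward: subclasses must override it
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        # self.forward is looked up on the instance's class first,
        # so a subclass's override wins over MiniModule.forward
        result = self.forward(*args, **kwargs)
        for hook in self._forward_hooks:
            hook_result = hook(self, args, result)
            if hook_result is not None:
                result = hook_result
        return result


class Double(MiniModule):
    def forward(self, x):
        return 2 * x


class AddOneThenDouble(MiniModule):
    def __init__(self):
        super().__init__()
        self.double = Double()

    def forward(self, x):
        # Calling the submodule instance goes through its __call__ as well
        return self.double(x + 1)


net = AddOneThenDouble()
print(net(3))                                    # 8
print(net(3) == type(net).__call__(net, 3))      # True: net(3) runs type(net).__call__
print(MiniModule.__call__ is type(net).__call__)  # True: __call__ is inherited, not redefined
```

AddOneThenDouble never defines __call__. It inherits MiniModule's, just as RetinaFace inherits nn.Module's.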

Summary

  1. out = net(input) calls the __call__ method.
  2. __call__ calls forward. Because every network, such as RetinaFace and FPN above, overrides forward, each call to forward runs the overridden version. That answers the question above.
