为什么pytorch在定义模型和损失函数时能直接输入数据

想到这个问题主要在于有个朋友问我为什么定义MSE时,损失函数定义为loss = torch.nn.MSELoss(),使用时却可以直接输入数数据loss(x,y),这个问题和定义网络的时候一样,在定义的时候只需要写网络层是什么。,而使用时只需要输入x
为什么pytorch在定义模型和损失函数时能直接输入数据_第1张图片
很显然,定义是__init__干的活,而计算则是__forward__干的活,但问题就在于为什么网络能够自动唤起forward而不需要调用,我们类比损失函数和自定义损失函数:
通过查找函数,发现Pytorch给定的损失函数,如MSE等都继承自class _Loss(Module):,而该类又继承自Moudle,对于自定义损失函数而言:为什么pytorch在定义模型和损失函数时能直接输入数据_第2张图片
依然需要继承自Module类,很显然,问题的源头就在于Module类,通过向上查找Module,可以发现一句__call__ : Callable[..., Any] = _call_impl,在这里通过_call_impl使用了__call__函数,且有result = self.forward(*input, **kwargs)
__call__函数的功能类似于在类中重载 () 运算符,使得类实例对象可以像调用普通函数那样,以“对象名()”的形式使用。
因此,通过使用__call__函数Pytorch能自动调用forward方法

参考:
https://zhuanlan.zhihu.com/p/366461413

http://c.biancheng.net/view/2380.html

你可能感兴趣的:(pytorch,深度学习,人工智能,神经网络)