forward()
是怎么被调用的
看了一个源码,从最开始看到最后就看到了每个函数里面都有一个def forward()
方法。但是,没有看到调用的地方,甚至是参数,和方法的参数都不一样了。
那代码怎么看呐?从网上看了好多的方法,总是差一点,没有那么明显,就要到了重要的地方了,就结束了。
写博客的时候你会发现,段首没有办法缩进;
如果使用两个Tab键/四个空格键
的话,就成了下面这样:
1111111111111
2222222
解决方法:使用特殊占位符
,不同占位符所占空白是不一样大的。
  or   表示一个半角的空格
  or   表示一个全角的空格
   两个全角的空格(用的比较多)
or   不断行的空白格
效果展示:
表示一个半角的空格
表示一个全角的空格
两个全角的空格(用的比较多)
不断行的空白格
Python类
的执行顺序class Foo():
def __init__(self, x):
print("this is class of Foo.")
print("Foo类属性初始化")
self.f = "foo"
print(x)
class Children(Foo):
def __init__(self):
y = 1
print("this is class of Children.")
print("Children类属性初始化")
super(Children, self).__init__(y)
# 进入`Foo类`中,向`Foo类`中传入参数`y`,同时初始化`Foo类属性`。
a = Children()
print(***)
print(a)
print(***)
print(a.f)
输出结果:
***
this is class of Children.
Children类属性初始化
this is class of Foo.
Foo类属性初始化
1
***
foo
执行顺序:
创建实例化对象:a = Children()
执行
print(a)
–>进入Childern类
–>初始化Childern类
参数,执行def __init__(self):
函数 -->进入Children父类Foo
,传入参数y
并初始化父类Foo
参数super(Children, self).__init__(y)
,执行Foo
中的参数初始化。
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
#1个输入图像通道,6个输出通道,3x3平方卷积核
self.conv1 = nn.Conv2d(1, 6, 3)
self.conv2 = nn.Conv2d(6, 16, 3)
# an affine operation: y = Wx + b
self.fc1 = nn.Linear(16 * 6 * 6, 120) # 6*6 from image dimension
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
# Max pooling over a (2, 2) window
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
# If the size is a square you can only specify a single number
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, self.num_flat_features(x))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def num_flat_features(self, x):
size = x.size()[1:] # all dimensions except the batch dimension
num_features = 1
for s in size:
num_features *= s
return num_features
net = Net()
print(net)
输出结果:
>>>Net(
(conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
(conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
(fc1): Linear(in_features=576, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
那关于
forward是怎么被调用的
,模型训练时,不需要使用forward,只要在实例化一个对象中传入对应的参数
就可以自动调用 forward 函数
。
# list 1
class A():
def __call__(self):
print('i can be called like a function')
a = A()
a()
输出结果:
>>> i can be called like a function
# list 2
class B():
def __init__(self):
print('i can be called like a function')
a = B()
print(a)
输出结果:
>>>i can be called like a function
>>><__main__.B at 0x5978fd0>
# list 3
class C():
def __init__(self):
print('a')
def __call__(self):
print('b')
a = C()
a()
输出结果:
>>>a
>>>b
# list 4
class D():
def __init__(self, init_age):
print('我年龄是:',init_age)
self.age = init_age
def __call__(self, added_age):
res = self.forward(added_age)
return res
def forward(self, input_):
print('forward 函数被调用了')
return input_ + self.age
print('对象初始化。。。。')
a = D(10)
print("*************split line*************")
input_param = a(2)
print("我现在的年龄是:", input_param)
输出结果:
>>>对象初始化。。。。
我年龄是: 10
*************split line*************
forward 函数被调用了
我现在的年龄是: 12
pytorch主要也是就是按照__call__, init,forward三个函数实现网络层之间的架构的,所以从以上例子可看出,定义__call__方法的类可以当作函数调用,当把定义的网络模型model当作函数调用的时候就自动调用定义的网络模型的forward方法。
super
继承,也需要了解一下# list 1
class Net(nn.Module):
def __init__(self):
print("this is Net")
self.a = 1
super(Net, self).__init__()
def forward(self, x):
print("this is forward of Net", x)
class Children(Net):
def __init__(self):
print("this is children")
super(Children, self).__init__()
a = Children()
print("*************split line*************")
a(1)
输出结果:
this is children
this is Net
*************split line*************
this is forward of Net 1
上面代码执行顺序:
创建
实例化对象
:a = Children()
执行a
:进入Childern类
–>初始化Childern类参数
,执行def __init__(self):
-->进入Children父类Net
, 并初始化父类Net
参数super(Children, self).__init__()
,执行Net
中的参数初始化 -->传入参数
并执行父类Net的forward()函数
。
nn.Module
源码部分class FPN(nn.Module):
def __init__(self,in_channels_list,out_channels):
self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride = 1)
self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride = 1)
self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride = 1)
def forward(self, input):
input = list(input.values())
output1 = self.output1(input[0])
output2 = self.output2(input[1])
output3 = self.output3(input[2])
out = [output1, output2, output3]
return out
class RetinaFace(nn.Module):
def __init__(self):
# 定义层结构,举例如下
self.fpn = FPN()
def forward(self, inputs):
out = self.fpn(inputs)
return out
net = RetinaFace()
out = net(image) # 图像作为输入,经过net做正向传播,得到输出(分类/框/。。。)
看一下nn.Module
的定义(截取有用部分):
class Module(object):
def forward(self, *input):
r"""Defines the computation performed at every call.
Should be overridden by all subclasses.
.. note::
Although the recipe for forward pass needs to be defined within
this function, one should call the :class:`Module` instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
"""
raise NotImplementedError
def __call__(self, *input, **kwargs):
for hook in self._forward_pre_hooks.values():
result = hook(self, input)
if result is not None:
if not isinstance(result, tuple):
result = (result,)
input = result
if torch._C._get_tracing_state():
result = self._slow_forward(*input, **kwargs)
else:
result = self.forward(*input, **kwargs) # 重点!!
for hook in self._forward_hooks.values():
hook_result = hook(self, input, result)
if hook_result is not None:
result = hook_result
if len(self._backward_hooks) > 0:
var = result
while not isinstance(var, torch.Tensor):
if isinstance(var, dict):
var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
else:
var = var[0]
grad_fn = var.grad_fn
if grad_fn is not None:
for hook in self._backward_hooks.values():
wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
grad_fn.register_hook(wrapper)
return result
(如:此例的net(image))
。一般调用在类中定义的函数的方法是:example_class_instance.func(),如果只是使用example_class_instance(),那么这个操作就是在调用__call__这个内置方法out = net(image)
实际上就是调用了net
的__call__方法
,net
的__call__方法
没有显式定义,那么就使用它的父类方法
,也就是调用nn.Module的__call__方法
,它调用了forward方法
,又有,net类
中定义了forward方法
,所以使用重写的forward方法
。out = self.fpn(inputs)
也是先调用__call__方法
,它进一步调用forward方法
,而forward方法
被FPN类重写
,故调用重写后的forward方法
。定义模型:
class Module(nn.Module):
def __init__(self):
super(Module, self).__init__()
# ......
def forward(self, x):
# ......
return x
data = ..... #输入数据
# 实例化一个对象
module = Module()
# 前向传播 直接把输入传入实列化
module(data)
#没有使用module.forward(data)
#实际上module(data) 等价于module.forward(data)
等价的原因是因为
python calss
中的__call__
可以让类像函数一样调用,当执行model(x)
的时候,底层自动调用forward方法
计算结果。
在__call__ 里可调用其它的函数:
class A():
def __call__(self, param):
print('我在__call__中,传入参数',param)
res = self.forward(param)
return res
def forward(self, x):
print('我在forward函数中,传入参数类型是值为: ',x)
return x
a = A()
y = a('i')
print("*****")
print("传入的参数是:", y)
输出结果:
>>> 我在__call__中,传入参数 i
>>>我在forward函数中,传入参数类型是值为: i
>>>*****
>>>传入的参数是: i
nn.Module 的__call__方法部分源码如下所示:
def __call__(self, *input, **kwargs):
result = self.forward(*input, **kwargs)
for hook in self._forward_hooks.values():
#将注册的hook拿出来用
hook_result = hook(self, input, result)
...
return result
可以看到,当执行model(x)的时候,底层自动调用forward方法计算结果。
class Animal():
def __init__(self,name):
self.name = name
def greet(self):
print('animal is %s' % self.name)
class Dog(Animal):
def greet(self):
super(Dog, self).greet()
print('wangwang')
a=Dog('dog')
a.greet()
输出结果:
>>>animal is dog
>>>wangwang
那么调用forward方法的具体流程是什么样的呢?具体流程是这样的:
以一个Module为例:
1. 调用Module的call方法
2. Module的call里面调用Module的forward方法
3. forward里面如果碰到Module的子类,回到第1步,如果碰到的是Function的子类,继续往下
4. 调用Function的call方法
5. Function的call方法调用了Function的forward方法。
6. Function的forward返回值
7. Module的forward返回值
8. 在Module的call进行forward_hook操作,然后返回值。
(1)像__ getitem__这种由两个双下划线构成的方法,被称为魔术方法。
(2)魔术方法是为了给python解释器用的。当使用len(collection)时,实际上调用的就是collection.__ len__方法。而在使用obj[key]的形式来访问元素时,实际上调用的是object.__ getitem__(key)方法。
(3)魔术方法是属于类的方法,也就是说不需要实例化类就可以访问到该方法,同时,实例化的对象都可以访问到该方法。
(4)使用__ getitem __ 和 __ len __方法,我们就可以实现一个对自定义数据类型的迭代和访问。
class Fun:
def __init__(self, x_list):
""" initialize the class instance
Args:
x_list: data with list type
Returns:
None
"""
if not isinstance(x_list, list):
raise ValueError("input x_list is not a list type")
self.data = x_list
print("intialize success")
def __getitem__(self, idx):
print("__getitem__ is called")
return self.data[idx]
def __len__(self):
print("__len__ is called")
return len(self.data)
fun = Fun(x_list=[1, 2, 3, 4, 5]) # intialize success
print(fun[2]) # 索引,调用的是第二个def
# __getitem__ is called
# 3
print(len(fun)) # 调用的是第三个def
# __len__ is called
# 5